Rewriting Go with AST transformation

At work, my team decided to switch our codebase from pkg/errors to Go 1.13’s native error wrapping. We used to wrap our errors like this:

var errBadStuff = errors.New("something happened")
...
err := errors.Wrapf(errBadStuff, "some context '%s'", label)

Starting in Go 1.13, we can wrap errors directly using fmt.Errorf and the %w verb:

err := fmt.Errorf("some context '%s': %w", label, errBadStuff)
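To make the target form concrete, here is a minimal sketch of native wrapping and of how the wrapped sentinel can still be detected afterwards. The doWork and isBadStuff functions are hypothetical names introduced only for this illustration.

```go
package main

import (
	"errors"
	"fmt"
)

// errBadStuff is a sentinel error, as in the snippets above.
var errBadStuff = errors.New("something happened")

// doWork is a hypothetical function that wraps the sentinel with context
// using the %w verb.
func doWork(label string) error {
	return fmt.Errorf("some context '%s': %w", label, errBadStuff)
}

// isBadStuff reports whether errBadStuff appears anywhere in err's wrap chain.
func isBadStuff(err error) bool {
	return errors.Is(err, errBadStuff)
}

func main() {
	err := doWork("demo")
	fmt.Println(err)             // some context 'demo': something happened
	fmt.Println(isBadStuff(err)) // true
}
```

Because %w records the inner error, errors.Is can find the sentinel without any help from pkg/errors.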

Rather than change hundreds of wrappings across tens of thousands of lines by hand or with kludgy string replacement, I wrote a program to rewrite our source files using Go’s AST instead.

Parsing, walking and rewriting an AST

The central function of the program takes in old source, parses it, transforms it and returns the new code. Parsing uses Go’s built-in packages token and parser to produce an AST for a file:

func Rewrite(filename string, oldSource []byte) ([]byte, error) {
    fset := token.NewFileSet()
    oldAST, err := parser.ParseFile(fset, filename, oldSource, parser.ParseComments)
    if err != nil {
        return nil, fmt.Errorf("error parsing %s: %w", filename, err)
    }
    ...
}

Confusingly, the parser.ParseComments mode flag passed to ParseFile means “include comments in the parse, not just code”. I need it so that comments survive the rewrite.

I used Fatih Arslan’s astrewrite library for the AST manipulation, as it provides a Walk function that, unlike Go’s built-in AST walker, lets me return a modified node to replace the original in the AST structure.

func Rewrite(filename string, oldSource []byte) ([]byte, error) {
    ...
    newAST := astrewrite.Walk(oldAST, visitor)
    ...
}

Given the rewritten AST, the output is rendered using go/format, the same package that gofmt uses:

func Rewrite(filename string, oldSource []byte) ([]byte, error) {
    ...
    buf := &bytes.Buffer{}
    err = format.Node(buf, fset, newAST)
    if err != nil {
        return nil, fmt.Errorf("error formatting new code: %w", err)
    }
    return buf.Bytes(), nil
}

That’s the high-level flow. All the interesting details are in the visitor function that finds the nodes of interest and rewrites them.
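The parse, transform, format flow can be tried end to end with only the standard library. The sketch below is not the article's Rewrite function: it swaps astrewrite.Walk for ast.Inspect, which cannot replace nodes but can mutate them in place, and it performs a toy transformation (renaming an identifier). The renameIdent name and its parameters are assumptions made for illustration.

```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
)

// renameIdent is a hypothetical, stdlib-only stand-in for Rewrite: it parses
// src, renames every identifier named from to to, and formats the result.
func renameIdent(filename string, src []byte, from, to string) ([]byte, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("error parsing %s: %w", filename, err)
	}

	// ast.Inspect visits every node; mutating a node's fields in place is
	// enough for this toy transformation.
	ast.Inspect(file, func(n ast.Node) bool {
		if id, ok := n.(*ast.Ident); ok && id.Name == from {
			id.Name = to
		}
		return true
	})

	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return nil, fmt.Errorf("error formatting new code: %w", err)
	}
	return buf.Bytes(), nil
}

func main() {
	src := []byte("package demo\n\n// answer is kept as a comment.\nvar answer = 42\n")
	out, err := renameIdent("demo.go", src, "answer", "reply")
	if err != nil {
		panic(err)
	}
	fmt.Print(string(out))
}
```

The comment in the input mentions "answer" but is left alone, since comments are not identifiers; it is preserved in the output because the file was parsed with parser.ParseComments.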

Inspecting the AST

Finding nodes in the AST and changing them requires understanding what the AST looks like. My first attempt just dumped out nodes to stdout in visitor, but I later found an online Go AST Viewer that makes it easy to see the AST. By giving it a small bit of code, I can see how that looks in AST-form.
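For readers who prefer to stay in a terminal, the standard library can produce a comparable dump. This is a small sketch using parser.ParseExpr and ast.Fprint; the dumpExpr helper is a hypothetical name, and the output layout differs somewhat from the online viewer's (it is line-numbered and, with a nil FileSet, shows positions as raw offsets).

```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
)

// dumpExpr is a hypothetical helper that parses a single Go expression and
// returns a field-by-field dump of its AST.
func dumpExpr(src string) (string, error) {
	// ParseExpr handles a bare expression, so no package clause is needed.
	expr, err := parser.ParseExpr(src)
	if err != nil {
		return "", err
	}
	var buf bytes.Buffer
	// ast.NotNilFilter hides nil fields to keep the dump short.
	if err := ast.Fprint(&buf, nil, expr, ast.NotNilFilter); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func main() {
	out, err := dumpExpr(`errors.Wrapf(err, "foo '%s'", label)`)
	if err != nil {
		panic(err)
	}
	fmt.Print(out)
}
```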

For example, here is errors.Wrapf("foo '%s'", err) in its AST form, which appears as an *ast.CallExpr:

X: *ast.CallExpr {
   Fun: *ast.SelectorExpr {
      X: *ast.Ident {
         NamePos: 10:2
         Name: "errors"
      }
      Sel: *ast.Ident {
         NamePos: 10:9
         Name: "Wrapf"
      }
   }
   Lparen: 10:14
   Args: []ast.Expr (len = 2) {
      0: *ast.BasicLit {
         ValuePos: 10:15
         Kind: STRING
         Value: "\"foo '%s'\""
      }
      1: *ast.Ident {
         NamePos: 10:27
         Name: "err"
         Obj: *(obj @ 54)
      }
   }
   Ellipsis: -
   Rparen: 10:30
}

The important parts are the Fun and Args fields, which give the function “name” and arguments, respectively. The name of the call is in a SelectorExpr, split into X and Sel, but the AST dump is slightly deceptive because the X field of SelectorExpr actually has an interface type:

type SelectorExpr struct {
    X   Expr   // expression
    Sel *Ident // field selector
}

So when checking whether a CallExpr is a call to errors.Wrapf, I must first check that its Fun is a *ast.SelectorExpr and then that the selector’s X has the concrete type *ast.Ident.

Here’s a helper function I wrote to extract the function call name out of a CallExpr:

func getCallExprLiteral(c *ast.CallExpr) string {
    s, ok := c.Fun.(*ast.SelectorExpr)
    if !ok {
        return ""
    }

    i, ok := s.X.(*ast.Ident)
    if !ok {
        return ""
    }

    return i.Name + "." + s.Sel.Name
}

If the CallExpr doesn’t give us a literal function name, the helper returns an empty string to signal that no literal name was found.
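A quick way to see which calls the helper recognizes is to feed it a few parsed expressions. In this sketch, getCallExprLiteral is copied from above, while callName is a hypothetical wrapper added only for experimenting.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
)

// getCallExprLiteral is the helper from the article.
func getCallExprLiteral(c *ast.CallExpr) string {
	s, ok := c.Fun.(*ast.SelectorExpr)
	if !ok {
		return ""
	}

	i, ok := s.X.(*ast.Ident)
	if !ok {
		return ""
	}

	return i.Name + "." + s.Sel.Name
}

// callName is a hypothetical wrapper: it parses src as an expression and
// returns the literal call name, or "" if src is not a matching call.
func callName(src string) string {
	expr, err := parser.ParseExpr(src)
	if err != nil {
		return ""
	}
	call, ok := expr.(*ast.CallExpr)
	if !ok {
		return ""
	}
	return getCallExprLiteral(call)
}

func main() {
	fmt.Printf("%q\n", callName(`errors.Wrapf(err, "foo '%s'", label)`)) // "errors.Wrapf"
	fmt.Printf("%q\n", callName(`Wrapf(err, "foo")`))                    // "": Fun is a plain Ident
	fmt.Printf("%q\n", callName(`a.b.Wrapf(err, "foo")`))                // "": X is another selector
	fmt.Printf("%q\n", callName(`newErrs().Wrap(err, "foo")`))           // "": X is a call
}
```

Note that the match is purely by name: the helper cannot tell whether errors refers to pkg/errors or to some local variable that happens to be called errors.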

Finding a node to rewrite

Knowing the AST structure and having a helper to extract a function name from the AST is all that’s needed for the “find a node to rewrite” part of the visitor function. If it’s visiting a CallExpr, the visitor just delegates to a handler for that type.

func visitor(n ast.Node) (ast.Node, bool) {
    switch v := n.(type) {
    case *ast.CallExpr:
        return handleCallExpr(v)
    default:
        return n, true
    }
}

The handler only has something to do if the function name matches the names I’m looking for from pkg/errors.

func handleCallExpr(ce *ast.CallExpr) (ast.Node, bool) {
    name := getCallExprLiteral(ce)
    switch name {
    case "errors.Wrap", "errors.Wrapf":
        return rewriteWrap(ce), true
    default:
        return ce, true
    }
}

If it does match, then the program rewrites it. Otherwise, nodes are returned unchanged from visitor.

Rewriting an AST node

Here are the signatures for the pkg/errors functions to rewrite:

func Wrap(err error, message string) error
func Wrapf(err error, format string, args ...interface{}) error

Three things need to happen to rewrite these to fmt.Errorf. First, the error argument needs to rotate from the beginning to the end of the argument list. Second, the message or format strings need to have ": %w" appended to them. Finally, the function name has to change.

Everything needed for the first two changes is in the *ast.CallExpr.Args field. For the rotation, I create a new argument list, copy everything but the original first argument, then append the original first argument.

func rewriteWrap(ce *ast.CallExpr) *ast.CallExpr {
    // Rotate err to the end of a new args list
    newArgs := make([]ast.Expr, len(ce.Args)-1)
    copy(newArgs, ce.Args[1:])
    newArgs = append(newArgs, ce.Args[0])
    ...
}
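The rotation can be checked in isolation by parsing a call and printing its arguments in their new order. This is a sketch under a few assumptions: rotateFirstToEnd and rotatedArgs are hypothetical names, the length guard is an addition for safety, and types.ExprString is used only as a convenient way to render each argument as source text.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/types"
	"strings"
)

// rotateFirstToEnd returns a new slice with the first argument moved to the
// end. Lists with fewer than two arguments are returned unchanged.
func rotateFirstToEnd(args []ast.Expr) []ast.Expr {
	if len(args) < 2 {
		return args
	}
	newArgs := make([]ast.Expr, 0, len(args))
	newArgs = append(newArgs, args[1:]...)
	return append(newArgs, args[0])
}

// rotatedArgs parses src as a call expression and returns its rotated
// arguments rendered as source text.
func rotatedArgs(src string) ([]string, error) {
	expr, err := parser.ParseExpr(src)
	if err != nil {
		return nil, err
	}
	call, ok := expr.(*ast.CallExpr)
	if !ok {
		return nil, fmt.Errorf("%q is not a call expression", src)
	}
	var out []string
	for _, arg := range rotateFirstToEnd(call.Args) {
		out = append(out, types.ExprString(arg))
	}
	return out, nil
}

func main() {
	args, err := rotatedArgs(`errors.Wrapf(err, "ctx '%s'", label)`)
	if err != nil {
		panic(err)
	}
	fmt.Println(strings.Join(args, ", ")) // "ctx '%s'", label, err
}
```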

For appending ": %w", I have to account for whether the message/format is a literal string or another expression, such as a global variable or a function that returns a string.

If it’s a literal, then the literal value itself needs to be amended. Here’s how a format string looks as an *ast.BasicLit type from the CallExpr example above:

*ast.BasicLit {
     ValuePos: 10:15
     Kind: STRING
     Value: "\"foo '%s'\""
}

Note that the Value field includes the quotation marks. To append, I can remove the trailing " and add : %w".

If the value isn’t a BasicLit, then given some expression X, I need to replace that with the expression X + ": %w". In AST terms, that’s an *ast.BinaryExpr that adds the original argument and a new BasicLit.

(How did I know that? I guessed, based on my experience with other languages’ ASTs, but if I hadn’t known, I could have put an example into the AST Viewer to see what it shows.)

Here’s how that looks in code. If the first argument is a basic literal, amend the Value; otherwise, replace the first argument with a new BinaryExpr:

func rewriteWrap(ce *ast.CallExpr) *ast.CallExpr {
    ...
    // If the format string is a literal, we can rewrite it:
    //     "......" -> "......: %w"
    // Otherwise, we replace it with a binary op to add the wrap code:
    //     SomeNonLiteral -> SomeNonLiteral + ": %w"
    fmtStr, ok := newArgs[0].(*ast.BasicLit)
    if ok {
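        // Strip the closing quote (`"`, or a backtick for a raw string),
        // then append the wrap code followed by that same closing quote
        last := len(fmtStr.Value) - 1
        fmtStr.Value = fmtStr.Value[:last] + ": %w" + fmtStr.Value[last:]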
    } else {
        binOp := &ast.BinaryExpr{
            X:  newArgs[0],
            Op: token.ADD,
            Y:  &ast.BasicLit{Kind: token.STRING, Value: `": %w"`},
        }
        newArgs[0] = binOp
    }
    ...
}
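Both branches can be exercised on their own by parsing small expressions and rendering the result. The following is a sketch, not the tool's code: appendWrapVerb and wrapVerbSource are hypothetical names, the function returns a new node instead of mutating the literal, and it reuses whatever closing quote the literal has so that raw string literals keep their backticks.

```go
package main

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
)

// appendWrapVerb returns an expression equivalent to e with ": %w" appended.
// String literals are extended directly; anything else becomes e + ": %w".
func appendWrapVerb(e ast.Expr) ast.Expr {
	if lit, ok := e.(*ast.BasicLit); ok && lit.Kind == token.STRING {
		// Value includes the quotes, so split just before the closing one.
		last := len(lit.Value) - 1
		return &ast.BasicLit{
			ValuePos: lit.ValuePos,
			Kind:     token.STRING,
			Value:    lit.Value[:last] + ": %w" + lit.Value[last:],
		}
	}
	return &ast.BinaryExpr{
		X:  e,
		Op: token.ADD,
		Y:  &ast.BasicLit{Kind: token.STRING, Value: `": %w"`},
	}
}

// wrapVerbSource parses src as an expression, applies appendWrapVerb, and
// renders the result as source text.
func wrapVerbSource(src string) (string, error) {
	expr, err := parser.ParseExpr(src)
	if err != nil {
		return "", err
	}
	return types.ExprString(appendWrapVerb(expr)), nil
}

func main() {
	for _, src := range []string{"\"foo '%s'\"", "`raw '%s'`", "msgTemplate", "buildMsg(x)"} {
		out, err := wrapVerbSource(src)
		if err != nil {
			panic(err)
		}
		fmt.Println(out)
	}
	// Prints:
	// "foo '%s': %w"
	// `raw '%s': %w`
	// msgTemplate + ": %w"
	// buildMsg(x) + ": %w"
}
```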

The last part of the rewrite is to return a new CallExpr with fmt.Errorf as the function name and with the new argument list.

func rewriteWrap(ce *ast.CallExpr) *ast.CallExpr {
    ...
    return newErrorfExpr(newArgs)
}

func newErrorfExpr(args []ast.Expr) *ast.CallExpr {
    return &ast.CallExpr{
        Fun: &ast.SelectorExpr{
            X:   &ast.Ident{Name: "fmt"},
            Sel: &ast.Ident{Name: "Errorf"},
        },
        Args: args,
    }
}

That’s it! The visitor function will now rewrite errors.Wrap and errors.Wrapf to fmt.Errorf.
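To see all three steps working together without any third-party dependency, here is a condensed, stdlib-only variation. It is a sketch rather than the tool's actual implementation: instead of returning a replacement node through astrewrite.Walk, it mutates each matching CallExpr in place from ast.Inspect. The rewrapInPlace name and the sample input are assumptions for illustration.

```go
package main

import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/parser"
	"go/token"
)

// rewrapInPlace rewrites errors.Wrap/Wrapf calls in src to fmt.Errorf by
// mutating the matching CallExpr nodes in place. Imports are left untouched.
func rewrapInPlace(filename string, src []byte) ([]byte, error) {
	fset := token.NewFileSet()
	file, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("error parsing %s: %w", filename, err)
	}

	ast.Inspect(file, func(n ast.Node) bool {
		ce, ok := n.(*ast.CallExpr)
		if !ok || len(ce.Args) < 2 {
			return true
		}
		sel, ok := ce.Fun.(*ast.SelectorExpr)
		if !ok || (sel.Sel.Name != "Wrap" && sel.Sel.Name != "Wrapf") {
			return true
		}
		pkg, ok := sel.X.(*ast.Ident)
		if !ok || pkg.Name != "errors" {
			return true
		}

		// Step 1: rotate the error argument to the end of a new list.
		args := make([]ast.Expr, 0, len(ce.Args))
		args = append(args, ce.Args[1:]...)
		args = append(args, ce.Args[0])

		// Step 2: append ": %w" to the message or format argument.
		if lit, ok := args[0].(*ast.BasicLit); ok && lit.Kind == token.STRING {
			last := len(lit.Value) - 1
			lit.Value = lit.Value[:last] + ": %w" + lit.Value[last:]
		} else {
			args[0] = &ast.BinaryExpr{
				X:  args[0],
				Op: token.ADD,
				Y:  &ast.BasicLit{Kind: token.STRING, Value: `": %w"`},
			}
		}
		ce.Args = args

		// Step 3: change the function name to fmt.Errorf.
		pkg.Name, sel.Sel.Name = "fmt", "Errorf"
		return true
	})

	var buf bytes.Buffer
	if err := format.Node(&buf, fset, file); err != nil {
		return nil, fmt.Errorf("error formatting new code: %w", err)
	}
	return buf.Bytes(), nil
}

func main() {
	src := []byte(`package demo

import "github.com/pkg/errors"

func load(label string, err error) error {
	return errors.Wrapf(err, "loading '%s'", label)
}
`)
	out, err := rewrapInPlace("demo.go", src)
	if err != nil {
		panic(err)
	}
	fmt.Print(string(out))
}
```

The output still imports github.com/pkg/errors and does not import fmt; fixing the import block is a separate step.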

Applying it to the codebase

The main function I wrote only handles individual files. While I didn’t show it in this article, it also replaces the github.com/pkg/errors import with the standard library’s errors package so that errors.New still resolves. But it’s possible that a file no longer needs errors at all, only fmt, so in the end, I used a loop to rewrap the errors with my program and then wash the results through goimports:

for f in $(find . -iname "*.go"); do go-rewrap-errors -w $f; goimports -w $f; done

This did exactly what I wanted: it rewrote hundreds of wrapping calls across thousands of lines of code, without messing with my editor or string replacement. And it’s a reusable tool for other codebases at my company (or one that you can use via the link at the end of this article).

I did have to go through and fix up a couple of things by hand:

  • Our custom error types had a Cause method to unwrap the inner error. I changed those to Unwrap, which is what built-in errors expects. There were so few of these that it was faster to edit by hand than do it via AST transformation.
  • In a few places, we were using pkg/errors.Cause to unwrap errors. I changed those to use built-in errors.As or errors.Is.

We weren’t using other features of pkg/errors so there was nothing else to do.
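The two manual fixes can be illustrated with a small sketch. Everything named here (lookupError, errNotFound, find, failedKey, isNotFound) is hypothetical and not taken from our codebase; the point is only the shape of the change: an Unwrap method where Cause used to be, and errors.Is or errors.As where pkg/errors.Cause was called.

```go
package main

import (
	"errors"
	"fmt"
)

// errNotFound is a hypothetical sentinel error.
var errNotFound = errors.New("not found")

// lookupError is a hypothetical custom error type carrying an inner error.
type lookupError struct {
	key   string
	inner error
}

func (e *lookupError) Error() string {
	return "lookup " + e.key + ": " + e.inner.Error()
}

// Unwrap takes the place of the old Cause method, so errors.Is and errors.As
// can see through lookupError to the inner error.
func (e *lookupError) Unwrap() error { return e.inner }

// find always fails, wrapping a lookupError with extra context.
func find(key string) error {
	return fmt.Errorf("find failed: %w", &lookupError{key: key, inner: errNotFound})
}

// failedKey uses errors.As to extract the custom type from the chain.
func failedKey(err error) (string, bool) {
	var le *lookupError
	if errors.As(err, &le) {
		return le.key, true
	}
	return "", false
}

// isNotFound uses errors.Is to look for the sentinel anywhere in the chain.
func isNotFound(err error) bool {
	return errors.Is(err, errNotFound)
}

func main() {
	err := find("user:42")
	fmt.Println(err)             // find failed: lookup user:42: not found
	fmt.Println(isNotFound(err)) // true
	fmt.Println(failedKey(err))  // user:42 true
}
```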

While the Go AST seems intimidating at first, it’s an extremely powerful tool for code transformation. With some visualization to see what the before and after code looks like, and some trial and error, it made the hard job of updating our codebase seem very easy.

The full code for this article may be found at github.com/xdg-go/go-rewrap-errors