kilo/vendor/honnef.co/go/tools/quickfix/lint.go

955 lines
26 KiB
Go
Raw Normal View History

package quickfix
import (
"fmt"
"go/ast"
"go/constant"
"go/token"
"go/types"
"strings"
"honnef.co/go/tools/analysis/code"
"honnef.co/go/tools/analysis/edit"
"honnef.co/go/tools/analysis/report"
"honnef.co/go/tools/go/ast/astutil"
"honnef.co/go/tools/go/types/typeutil"
"honnef.co/go/tools/pattern"
"golang.org/x/tools/go/analysis"
)
func negateDeMorgan(expr ast.Expr, recursive bool) ast.Expr {
switch expr := expr.(type) {
case *ast.BinaryExpr:
var out ast.BinaryExpr
switch expr.Op {
case token.EQL:
out.X = expr.X
out.Op = token.NEQ
out.Y = expr.Y
case token.LSS:
out.X = expr.X
out.Op = token.GEQ
out.Y = expr.Y
case token.GTR:
out.X = expr.X
out.Op = token.LEQ
out.Y = expr.Y
case token.NEQ:
out.X = expr.X
out.Op = token.EQL
out.Y = expr.Y
case token.LEQ:
out.X = expr.X
out.Op = token.GTR
out.Y = expr.Y
case token.GEQ:
out.X = expr.X
out.Op = token.LSS
out.Y = expr.Y
case token.LAND:
out.X = negateDeMorgan(expr.X, recursive)
out.Op = token.LOR
out.Y = negateDeMorgan(expr.Y, recursive)
case token.LOR:
out.X = negateDeMorgan(expr.X, recursive)
out.Op = token.LAND
out.Y = negateDeMorgan(expr.Y, recursive)
}
return &out
case *ast.ParenExpr:
if recursive {
return &ast.ParenExpr{
X: negateDeMorgan(expr.X, recursive),
}
} else {
return &ast.UnaryExpr{
Op: token.NOT,
X: expr,
}
}
case *ast.UnaryExpr:
if expr.Op == token.NOT {
return expr.X
} else {
return &ast.UnaryExpr{
Op: token.NOT,
X: expr,
}
}
default:
return &ast.UnaryExpr{
Op: token.NOT,
X: expr,
}
}
}
// simplifyParentheses removes every ParenExpr from node. It also
// rebalances right-nested chains of && or || into left-nested ones, so
// that a && (b && c) becomes (a && b) && c, which renders as a && b && c.
// It repeats full passes until one makes no change, and returns the
// (possibly new) root.
//
// Dropping ParenExpr nodes relies on the printer reinserting the
// parentheses that operator precedence requires when the tree is
// rendered.
//
// Only && and || are rebalanced. Rotating other operators does not
// preserve meaning in general: a - (b - c) is not (a - b) - c, and
// floating-point + and * are not associative either.
//
// The tree is modified in place; the callers in this file pass a copy
// made with astutil.CopyExpr.
func simplifyParentheses(node ast.Expr) ast.Expr {
	var changed bool
	// XXX accept list of ops to operate on
	// XXX copy AST node, don't modify in place
	post := func(c *astutil.Cursor) bool {
		out := c.Node()
		if paren, ok := c.Node().(*ast.ParenExpr); ok {
			out = paren.X
		}

		if binop, ok := out.(*ast.BinaryExpr); ok && (binop.Op == token.LAND || binop.Op == token.LOR) {
			if right, ok := binop.Y.(*ast.BinaryExpr); ok && binop.Op == right.Op {
				// Rotate a op (b op c) into (a op b) op c. For && and ||
				// this preserves both the result and the left-to-right,
				// short-circuiting evaluation order.
				binop.Y = right.X
				right.X = binop
				out = right
			}
		}

		if out != c.Node() {
			changed = true
			c.Replace(out)
		}
		return true
	}

	for changed = true; changed; {
		changed = false
		node = astutil.Apply(node, nil, post).(ast.Expr)
	}

	return node
}
// demorganQ matches a negated binary expression such as !(a && b),
// binding the binary expression to "expr".
var demorganQ = pattern.MustParse(`(UnaryExpr "!" expr@(BinaryExpr _ _ _))`)
// CheckDeMorgan reports negated binary expressions, !(x op y). It offers
// fixes that push the negation inward using De Morgan's laws, either once
// at the top level or recursively into parenthesized subexpressions. Each
// of these has an additional variant with redundant parentheses removed,
// offered only when it renders differently. Expressions containing
// floating-point operands are skipped.
func CheckDeMorgan(pass *analysis.Pass) (interface{}, error) {
	// TODO(dh): support going in the other direction, e.g. turning `!a && !b && !c` into `!(a || b || c)`

	// hasFloats reports whether any subexpression is of type float.
	// Subexpressions the type checker records no type for (for example
	// the KeyValueExpr elements of a composite literal, or the Ellipsis of
	// a variadic func-literal parameter) are skipped. Calling Underlying
	// on their nil types.Type would panic.
	hasFloats := func(expr ast.Expr) bool {
		found := false
		ast.Inspect(expr, func(node ast.Node) bool {
			expr, ok := node.(ast.Expr)
			if !ok {
				return true
			}
			T := pass.TypesInfo.TypeOf(expr)
			if T == nil {
				return true
			}
			if basic, ok := T.Underlying().(*types.Basic); ok && basic.Info()&types.IsFloat != 0 {
				found = true
				return false
			}
			return true
		})
		return found
	}

	fn := func(node ast.Node, stack []ast.Node) {
		matcher, ok := code.Match(pass, demorganQ, node)
		if !ok {
			return
		}

		expr := matcher.State["expr"].(ast.Expr)

		// be extremely conservative when it comes to floats
		if hasFloats(expr) {
			return
		}

		n := negateDeMorgan(expr, false)
		nr := negateDeMorgan(expr, true)
		// simplifyParentheses rewrites in place, so simplify copies.
		nc, ok := astutil.CopyExpr(n)
		if !ok {
			return
		}
		ns := simplifyParentheses(nc)
		nrc, ok := astutil.CopyExpr(nr)
		if !ok {
			return
		}
		nrs := simplifyParentheses(nrc)

		var bn, bnr, bns, bnrs string
		switch stack[len(stack)-2].(type) {
		case *ast.BinaryExpr, *ast.IfStmt, *ast.ForStmt, *ast.SwitchStmt:
			// Always add parentheses when the negation is an operand of
			// another binary expression, and for if, for and switch. If
			// they're unnecessary, go/printer will strip them when the
			// whole file gets formatted.
			bn = report.Render(pass, &ast.ParenExpr{X: n})
			bnr = report.Render(pass, &ast.ParenExpr{X: nr})
			bns = report.Render(pass, &ast.ParenExpr{X: ns})
			bnrs = report.Render(pass, &ast.ParenExpr{X: nrs})
		default:
			// TODO are there other types where we don't want to strip parentheses?
			bn = report.Render(pass, n)
			bnr = report.Render(pass, nr)
			bns = report.Render(pass, ns)
			bnrs = report.Render(pass, nrs)
		}

		// Note: we cannot compare the ASTs directly, because
		// simplifyParentheses might have rebalanced trees without
		// affecting the rendered form.
		var fixes []analysis.SuggestedFix
		fixes = append(fixes, edit.Fix("Apply De Morgan's law", edit.ReplaceWithString(node, bn)))
		if bn != bns {
			fixes = append(fixes, edit.Fix("Apply De Morgan's law & simplify", edit.ReplaceWithString(node, bns)))
		}
		if bn != bnr {
			fixes = append(fixes, edit.Fix("Apply De Morgan's law recursively", edit.ReplaceWithString(node, bnr)))
			if bnr != bnrs {
				fixes = append(fixes, edit.Fix("Apply De Morgan's law recursively & simplify", edit.ReplaceWithString(node, bnrs)))
			}
		}

		report.Report(pass, node, "could apply De Morgan's law", report.Fixes(fixes...))
	}

	code.PreorderStack(pass, fn, (*ast.UnaryExpr)(nil))

	return nil, nil
}
// findSwitchPairs appends the equality comparisons that make up expr to
// *pairs. It reports whether expr can become the value list of a tagged
// switch case. That requires expr to be a single x == y comparison or a
// || chain of them, with parentheses looked through. Neither operand of
// any comparison may have side effects. Every left operand must be
// syntactically identical to the left operand of the first comparison
// already in *pairs.
//
// When it returns false, *pairs may hold a partial result.
func findSwitchPairs(pass *analysis.Pass, expr ast.Expr, pairs *[]*ast.BinaryExpr) bool {
	binexpr, ok := astutil.Unparen(expr).(*ast.BinaryExpr)
	if !ok {
		return false
	}
	switch binexpr.Op {
	case token.EQL:
		if code.MayHaveSideEffects(pass, binexpr.X, nil) || code.MayHaveSideEffects(pass, binexpr.Y, nil) {
			return false
		}
		// syntactic identity should suffice. we do not allow side
		// effects in the case clauses, so there should be no way for
		// values to change.
		if len(*pairs) > 0 && !astutil.Equal(binexpr.X, (*pairs)[0].X) {
			return false
		}
		*pairs = append(*pairs, binexpr)
		return true
	case token.LOR:
		return findSwitchPairs(pass, binexpr.X, pairs) && findSwitchPairs(pass, binexpr.Y, pairs)
	default:
		return false
	}
}
// CheckTaglessSwitch reports tagless switch statements (switch { ... }) in
// which every non-default case consists only of x == y comparisons,
// possibly combined with ||. All comparisons must share the same left
// operand x, and no operand may have side effects (see findSwitchPairs).
// It suggests rewriting the statement into a tagged switch on x, with the
// right operands of the comparisons becoming the case values.
func CheckTaglessSwitch(pass *analysis.Pass) (interface{}, error) {
	fn := func(node ast.Node) {
		swtch := node.(*ast.SwitchStmt)
		if swtch.Tag != nil || len(swtch.Body.List) == 0 {
			return
		}

		// pairs[i] collects the comparisons of the i-th case clause; it
		// stays empty for the default clause, whose List is nil.
		pairs := make([][]*ast.BinaryExpr, len(swtch.Body.List))
		for i, stmt := range swtch.Body.List {
			stmt := stmt.(*ast.CaseClause)
			for _, cond := range stmt.List {
				if !findSwitchPairs(pass, cond, &pairs[i]) {
					return
				}
			}
		}

		// All clauses must compare against the same expression x, which
		// becomes the switch tag.
		var x ast.Expr
		for _, pair := range pairs {
			if len(pair) == 0 {
				continue
			}
			if x == nil {
				x = pair[0].X
			} else {
				if !astutil.Equal(x, pair[0].X) {
					return
				}
			}
		}
		if x == nil {
			// the switch only has a default case
			if len(pairs) > 1 {
				panic("found more than one case clause with no pairs")
			}
			return
		}

		edits := make([]analysis.TextEdit, 0, len(swtch.Body.List)+1)
		for i, stmt := range swtch.Body.List {
			stmt := stmt.(*ast.CaseClause)
			if stmt.List == nil {
				// default clause: nothing to rewrite
				continue
			}

			// Replace the clause's expression list (up to the colon) with the
			// right-hand operands of its comparisons, dropping one level of
			// parentheses around each.
			var values []string
			for _, binexpr := range pairs[i] {
				y := binexpr.Y
				if p, ok := y.(*ast.ParenExpr); ok {
					y = p.X
				}
				values = append(values, report.Render(pass, y))
			}

			edits = append(edits, edit.ReplaceWithString(edit.Range{stmt.List[0].Pos(), stmt.Colon}, strings.Join(values, ", ")))
		}
		// Insert " x" directly after the switch keyword.
		pos := swtch.Switch + token.Pos(len("switch"))
		edits = append(edits, edit.ReplaceWithString(edit.Range{pos, pos}, " "+report.Render(pass, x)))
		report.Report(pass, swtch, fmt.Sprintf("could use tagged switch on %s", report.Render(pass, x)),
			report.Fixes(edit.Fix("Replace with tagged switch", edits...)))
	}

	code.Preorder(pass, fn, (*ast.SwitchStmt)(nil))
	return nil, nil
}
// CheckIfElseToSwitch reports if/else-if chains of at least two if
// statements whose conditions all consist of x == y comparisons, possibly
// combined with ||. All comparisons must share the same left operand x,
// and no operand may have side effects (see findSwitchPairs). It suggests
// rewriting the chain as a tagged switch on x.
//
// Chains are skipped if any if has an init statement, or if any body
// contains a branch statement other than goto anywhere within it.
func CheckIfElseToSwitch(pass *analysis.Pass) (interface{}, error) {
	fn := func(node ast.Node, stack []ast.Node) {
		if _, ok := stack[len(stack)-2].(*ast.IfStmt); ok {
			// this if statement is part of an if-else chain
			return
		}
		ifstmt := node.(*ast.IfStmt)

		// m maps each condition of the chain to the comparisons it consists of.
		m := map[ast.Expr][]*ast.BinaryExpr{}
		for item := ifstmt; item != nil; {
			if item.Init != nil {
				return
			}
			if item.Body == nil {
				return
			}

			// Bail out on break, continue and fallthrough. A break in
			// particular would bind to the new switch instead of the
			// statement it referred to before the rewrite.
			skip := false
			ast.Inspect(item.Body, func(node ast.Node) bool {
				if branch, ok := node.(*ast.BranchStmt); ok && branch.Tok != token.GOTO {
					skip = true
					return false
				}
				return true
			})
			if skip {
				return
			}

			var pairs []*ast.BinaryExpr
			if !findSwitchPairs(pass, item.Cond, &pairs) {
				return
			}
			m[item.Cond] = pairs
			switch els := item.Else.(type) {
			case *ast.IfStmt:
				item = els
			case *ast.BlockStmt, nil:
				item = nil
			default:
				panic(fmt.Sprintf("unreachable: %T", els))
			}
		}

		// All conditions must compare against the same expression x,
		// which becomes the switch tag.
		var x ast.Expr
		for _, pair := range m {
			if len(pair) == 0 {
				continue
			}
			if x == nil {
				x = pair[0].X
			} else {
				if !astutil.Equal(x, pair[0].X) {
					return
				}
			}
		}
		if x == nil {
			// shouldn't happen
			return
		}

		// We require at least two 'if' to make this suggestion, to
		// avoid clutter in the editor.
		if len(m) < 2 {
			return
		}

		// For each if in the chain, replace "if cond {" with
		// "case values:" and delete the text from the body's closing
		// brace up to the else branch ("} else "). A final else block's
		// opening brace becomes "default:". The last remaining closing
		// brace closes the switch.
		var edits []analysis.TextEdit
		for item := ifstmt; item != nil; {
			var end token.Pos
			if item.Else != nil {
				end = item.Else.Pos()
			} else {
				// delete up to but not including the closing brace.
				end = item.Body.Rbrace
			}

			var conds []string
			for _, cond := range m[item.Cond] {
				y := cond.Y
				if p, ok := y.(*ast.ParenExpr); ok {
					y = p.X
				}
				conds = append(conds, report.Render(pass, y))
			}
			sconds := strings.Join(conds, ", ")
			edits = append(edits,
				edit.ReplaceWithString(edit.Range{item.If, item.Body.Lbrace + 1}, "case "+sconds+":"),
				edit.Delete(edit.Range{item.Body.Rbrace, end}))

			switch els := item.Else.(type) {
			case *ast.IfStmt:
				item = els
			case *ast.BlockStmt:
				edits = append(edits, edit.ReplaceWithString(edit.Range{els.Lbrace, els.Lbrace + 1}, "default:"))
				item = nil
			case nil:
				item = nil
			default:
				panic(fmt.Sprintf("unreachable: %T", els))
			}
		}
		// FIXME this forces the first case to begin in column 0. try to fix the indentation
		edits = append(edits, edit.ReplaceWithString(edit.Range{ifstmt.If, ifstmt.If}, fmt.Sprintf("switch %s {\n", report.Render(pass, x))))
		report.Report(pass, ifstmt, fmt.Sprintf("could use tagged switch on %s", report.Render(pass, x)),
			report.Fixes(edit.Fix("Replace with tagged switch", edits...)),
			report.ShortRange())
	}
	code.PreorderStack(pass, fn, (*ast.IfStmt)(nil))
	return nil, nil
}
// stringsReplaceAllQ matches calls to strings.Replace, strings.SplitN,
// strings.SplitAfterN and their bytes counterparts whose final count
// argument is the integer literal -1. It binds the called function to
// "fn" and the -1 literal to "lit".
var stringsReplaceAllQ = pattern.MustParse(`(Or
(CallExpr fn@(Function "strings.Replace") [_ _ _ lit@(IntegerLiteral "-1")])
(CallExpr fn@(Function "strings.SplitN") [_ _ lit@(IntegerLiteral "-1")])
(CallExpr fn@(Function "strings.SplitAfterN") [_ _ lit@(IntegerLiteral "-1")])
(CallExpr fn@(Function "bytes.Replace") [_ _ _ lit@(IntegerLiteral "-1")])
(CallExpr fn@(Function "bytes.SplitN") [_ _ lit@(IntegerLiteral "-1")])
(CallExpr fn@(Function "bytes.SplitAfterN") [_ _ lit@(IntegerLiteral "-1")]))`)
// CheckStringsReplaceAll reports strings/bytes Replace, SplitN and
// SplitAfterN calls whose count argument is the literal -1. It suggests
// the equivalent function without a count parameter (for example
// strings.ReplaceAll) and deletes the -1 argument.
func CheckStringsReplaceAll(pass *analysis.Pass) (interface{}, error) {
	// XXX respect minimum Go version
	// FIXME(dh): create proper suggested fix for renamed import

	// withoutCount maps every function stringsReplaceAllQ can match to
	// its counterpart that takes no count argument.
	withoutCount := map[string]string{
		"strings.Replace":     "strings.ReplaceAll",
		"strings.SplitN":      "strings.Split",
		"strings.SplitAfterN": "strings.SplitAfter",
		"bytes.Replace":       "bytes.ReplaceAll",
		"bytes.SplitN":        "bytes.Split",
		"bytes.SplitAfterN":   "bytes.SplitAfter",
	}

	code.Preorder(pass, func(node ast.Node) {
		m, matched := code.Match(pass, stringsReplaceAllQ, node)
		if !matched {
			return
		}
		replacement, known := withoutCount[typeutil.FuncName(m.State["fn"].(*types.Func))]
		if !known {
			panic("unreachable")
		}

		call := node.(*ast.CallExpr)
		countArg := m.State["lit"].(ast.Node)
		fix := edit.Fix(fmt.Sprintf("Use %s instead", replacement),
			edit.ReplaceWithString(call.Fun, replacement),
			edit.Delete(countArg))
		report.Report(pass, call.Fun, fmt.Sprintf("could use %s instead", replacement), report.Fixes(fix))
	}, (*ast.CallExpr)(nil))
	return nil, nil
}
// mathPowQ matches math.Pow(x, n) where the exponent n is an integer
// literal. It binds the base expression to "x" and the exponent to "n".
var mathPowQ = pattern.MustParse(`(CallExpr (Function "math.Pow") [x (IntegerLiteral n)])`)
// CheckMathPow reports math.Pow(x, n) calls whose exponent n is an
// integer literal 0, 1, 2 or 3 and whose base x has no side effects. It
// suggests replacing them with 1.0, x, x*x or x*x*x respectively.
//
// When x is a constant whose standalone type is neither float64 nor
// untyped float, the expansion (except 1.0) is wrapped in float64(...).
func CheckMathPow(pass *analysis.Pass) (interface{}, error) {
	fn := func(node ast.Node) {
		matcher, ok := code.Match(pass, mathPowQ, node)
		if !ok {
			return
		}

		// x may be evaluated several times by the expansion, so it must be
		// free of side effects.
		x := matcher.State["x"].(ast.Expr)
		if code.MayHaveSideEffects(pass, x, nil) {
			return
		}
		// n is the exponent's value; give up if it does not fit in an int64.
		n, ok := constant.Int64Val(constant.ToInt(matcher.State["n"].(types.TypeAndValue).Value))
		if !ok {
			return
		}

		// math.Pow takes float64 arguments, so a non-constant x is
		// already float64. A constant x, however, may turn out to be e.g.
		// an untyped integer once it is no longer passed to math.Pow. In
		// that case the expansion needs an explicit float64 conversion.
		needConversion := false
		if T, ok := pass.TypesInfo.Types[x]; ok && T.Value != nil {
			info := types.Info{
				Types: map[ast.Expr]types.TypeAndValue{},
			}

			// determine if the constant expression would have type float64 if used on its own
			if err := types.CheckExpr(pass.Fset, pass.Pkg, x.Pos(), x, &info); err != nil {
				// This should not happen
				return
			}
			if T, ok := info.Types[x].Type.(*types.Basic); ok {
				if T.Kind() != types.UntypedFloat && T.Kind() != types.Float64 {
					needConversion = true
				}
			} else {
				needConversion = true
			}
		}

		var replacement ast.Expr
		switch n {
		case 0:
			replacement = &ast.BasicLit{
				Kind:  token.FLOAT,
				Value: "1.0",
			}
		case 1:
			replacement = x
		case 2, 3:
			// Build the left-nested product x*x (times x again for n == 3),
			// reusing the x node for every factor.
			r := &ast.BinaryExpr{
				X:  x,
				Op: token.MUL,
				Y:  x,
			}
			for i := 3; i <= int(n); i++ {
				r = &ast.BinaryExpr{
					X:  r,
					Op: token.MUL,
					Y:  x,
				}
			}

			// simplifyParentheses rewrites its argument in place, and r
			// contains the file's own x node; work on a copy.
			rc, ok := astutil.CopyExpr(r)
			if !ok {
				return
			}
			replacement = simplifyParentheses(rc)
		default:
			return
		}

		// The literal 1.0 used for n == 0 is never wrapped.
		if needConversion && n != 0 {
			replacement = &ast.CallExpr{
				Fun:  &ast.Ident{Name: "float64"},
				Args: []ast.Expr{replacement},
			}
		}

		report.Report(pass, node, "could expand call to math.Pow",
			report.Fixes(edit.Fix("Expand call to math.Pow", edit.ReplaceWithNode(pass.Fset, node, replacement))))
	}
	code.Preorder(pass, fn, (*ast.CallExpr)(nil))
	return nil, nil
}
// checkForLoopIfBreak matches a for loop without init, condition or post
// statement whose body begins with "if cond { break }". The if has no
// init and no else, and the break has no label. The pattern binds the if
// statement to "if" and its condition to "cond".
var checkForLoopIfBreak = pattern.MustParse(`(ForStmt nil nil nil if@(IfStmt nil cond (BranchStmt "BREAK" nil) nil):_)`)
// CheckForLoopIfBreak reports condition-less for loops that start with
// "if cond { break }". It suggests moving the negated condition into the
// loop header and deleting the if statement.
func CheckForLoopIfBreak(pass *analysis.Pass) (interface{}, error) {
	code.Preorder(pass, func(node ast.Node) {
		match, ok := code.Match(pass, checkForLoopIfBreak, node)
		if !ok {
			return
		}

		// The new condition is inserted right after the "for" keyword.
		afterFor := node.Pos() + token.Pos(len("for"))
		loopCond := negateDeMorgan(match.State["cond"].(ast.Expr), false)
		ifStmt := match.State["if"].(ast.Node)

		// FIXME(dh): we're leaving behind an empty line when we
		// delete the old if statement. However, we can't just delete
		// an additional character, in case there closing curly brace
		// is followed by a comment, or Windows newlines.
		insertCond := edit.ReplaceWithString(edit.Range{afterFor, afterFor}, " "+report.Render(pass, loopCond))
		fix := edit.Fix("Lift into loop condition", insertCond, edit.Delete(ifStmt))
		report.Report(pass, ifStmt, "could lift into loop condition", report.Fixes(fix))
	}, (*ast.ForStmt)(nil))
	return nil, nil
}
// checkConditionalAssignmentQ matches "x := true" or "x := false". It
// binds the declared object to "x", the right-hand side to "assign" and
// the constant's name to "b".
var checkConditionalAssignmentQ = pattern.MustParse(`(AssignStmt x@(Object _) ":=" assign@(Builtin b@(Or "true" "false")))`)

// checkConditionalAssignmentIfQ matches "if cond { x = true }" or
// "if cond { x = false }" with no init statement and no else branch. It
// binds the condition to "cond", the assigned object to "x" and the
// constant's name to "b".
var checkConditionalAssignmentIfQ = pattern.MustParse(`(IfStmt nil cond [(AssignStmt x@(Object _) "=" (Builtin b@(Or "true" "false")))] nil)`)
// CheckConditionalAssignment scans adjacent top-level statements of every
// function declaration and function literal body for
//
//	x := true
//	if cond { x = false }
//
// (or the same with true and false swapped). It suggests merging them
// into the single declaration x := !cond (respectively x := cond).
func CheckConditionalAssignment(pass *analysis.Pass) (interface{}, error) {
	fn := func(node ast.Node) {
		var body *ast.BlockStmt
		switch fnNode := node.(type) {
		case *ast.FuncDecl:
			body = fnNode.Body
		case *ast.FuncLit:
			body = fnNode.Body
		default:
			panic("unreachable")
		}
		if body == nil || len(body.List) < 2 {
			return
		}

		stmts := body.List
		for i := 1; i < len(stmts); i++ {
			decl, next := stmts[i-1], stmts[i]
			declMatch, ok := code.Match(pass, checkConditionalAssignmentQ, decl)
			if !ok {
				continue
			}
			ifMatch, ok := code.Match(pass, checkConditionalAssignmentIfQ, next)
			if !ok {
				continue
			}
			// The if must assign the opposite constant to the same variable.
			if declMatch.State["x"] != ifMatch.State["x"] || declMatch.State["b"] == ifMatch.State["b"] {
				continue
			}

			value := ifMatch.State["cond"].(ast.Expr)
			if declMatch.State["b"] == "true" {
				// x starts out true and becomes false when cond holds.
				value = &ast.UnaryExpr{Op: token.NOT, X: value}
			}

			merge := edit.Fix("Merge conditional assignment into variable declaration",
				edit.ReplaceWithNode(pass.Fset, declMatch.State["assign"].(ast.Node), value),
				edit.Delete(next))
			report.Report(pass, decl, "could merge conditional assignment into variable declaration", report.Fixes(merge))
		}
	}
	code.Preorder(pass, fn, (*ast.FuncDecl)(nil), (*ast.FuncLit)(nil))
	return nil, nil
}
// CheckExplicitEmbeddedSelector reports selector chains such as a.b.c in
// which an embedded field b can be omitted, because a.c resolves to the
// same object via the same path of fields. It offers one fix per
// removable field. When more than one field can be removed, it also
// offers a combined fix for the whole chain.
func CheckExplicitEmbeddedSelector(pass *analysis.Pass) (interface{}, error) {
	// Selector describes one uninterrupted run of selectors, X.f1.f2...fn.
	type Selector struct {
		Node   *ast.SelectorExpr // outermost selector of the run
		X      ast.Expr          // operand the run starts from
		Fields []*ast.Ident      // selected names, innermost first
	}

	// extractSelectors extracts uninterrupted sequences of selector expressions.
	// For example, for a.b.c().d.e[0].f.g three sequences will be returned: (X=a, X.b.c), (X=a.b.c(), X.d.e), and (X=a.b.c().d.e[0], X.f.g)
	//
	// It returns nil if the provided selector expression is not the root of a set of sequences.
	// For example, for a.b.c, if node is b.c, no selectors will be returned.
	extractSelectors := func(expr *ast.SelectorExpr) []Selector {
		// path is expected to run from the innermost node at expr.Pos()
		// outwards, as with x/tools' astutil.PathEnclosingInterval.
		path, _ := astutil.PathEnclosingInterval(code.File(pass, expr), expr.Pos(), expr.Pos())
		// The outermost selector enclosing the position must be expr itself.
		for i := len(path) - 1; i >= 0; i-- {
			if el, ok := path[i].(*ast.SelectorExpr); ok {
				if el != expr {
					// this expression is a subset of the entire chain, don't look at it.
					return nil
				}
				break
			}
		}

		// Walk outwards, starting a new Selector whenever a selector follows
		// a non-selector node (call, index, ...).
		inChain := false
		var out []Selector
		for _, el := range path {
			if expr, ok := el.(*ast.SelectorExpr); ok {
				if !inChain {
					inChain = true
					out = append(out, Selector{X: expr.X})
				}

				sel := &out[len(out)-1]
				sel.Fields = append(sel.Fields, expr.Sel)
				sel.Node = expr
			} else if inChain {
				inChain = false
			}
		}
		return out
	}

	fn := func(node ast.Node) {
		expr := node.(*ast.SelectorExpr)

		if _, ok := expr.X.(*ast.SelectorExpr); !ok {
			// Avoid the expensive call to PathEnclosingInterval for the common 1-level deep selector, which cannot be shortened.
			return
		}

		sels := extractSelectors(expr)
		if len(sels) == 0 {
			return
		}

		var edits []analysis.TextEdit
		for _, sel := range sels {
			// Slide a window (hop1, hop2) along the fields. base is the type
			// hop1 is selected from; after each step it becomes hop1's type.
		fieldLoop:
			for base, fields := pass.TypesInfo.TypeOf(sel.X), sel.Fields; len(fields) >= 2; base, fields = pass.TypesInfo.ObjectOf(fields[0]).Type(), fields[1:] {
				hop1 := fields[0]
				hop2 := fields[1]

				// the selector expression might be a qualified identifier, which cannot be simplified
				if base == types.Typ[types.Invalid] {
					continue fieldLoop
				}

				// Check if we can skip a field in the chain of selectors.
				// We can skip a field 'b' if a.b.c and a.c resolve to the same object and take the same path.
				//
				// We set addressable to true unconditionally because we've already successfully type-checked the program,
				// which means either the selector doesn't need addressability, or it is addressable.
				leftObj, leftLeg, _ := types.LookupFieldOrMethod(base, true, pass.Pkg, hop1.Name)

				// We can't skip fields that aren't embedded
				if !leftObj.(*types.Var).Embedded() {
					continue fieldLoop
				}

				directObj, directPath, _ := types.LookupFieldOrMethod(base, true, pass.Pkg, hop2.Name)

				// Fail fast if omitting the embedded field leads to a different object
				if directObj != pass.TypesInfo.ObjectOf(hop2) {
					continue fieldLoop
				}

				_, rightLeg, _ := types.LookupFieldOrMethod(leftObj.Type(), true, pass.Pkg, hop2.Name)

				// Fail fast if the paths are obviously different
				if len(directPath) != len(leftLeg)+len(rightLeg) {
					continue fieldLoop
				}

				// Make sure that omitting the embedded field will take the same path to the final object.
				// Multiple paths involving different fields may lead to the same type-checker object, causing different runtime behavior.
				for i := range directPath {
					if i < len(leftLeg) {
						if leftLeg[i] != directPath[i] {
							continue fieldLoop
						}
					} else {
						if rightLeg[i-len(leftLeg)] != directPath[i] {
							continue fieldLoop
						}
					}
				}

				// Delete "hop1." (from the start of hop1 up to hop2).
				e := edit.Delete(edit.Range{hop1.Pos(), hop2.Pos()})
				edits = append(edits, e)
				report.Report(pass, hop1, fmt.Sprintf("could remove embedded field %q from selector", hop1.Name),
					report.Fixes(edit.Fix(fmt.Sprintf("Remove embedded field %q from selector", hop1.Name), e)))
			}
		}

		// Offer to simplify all selector expressions at once
		if len(edits) > 1 {
			// Hack to prevent gopls from applying the Unnecessary tag to the diagnostic. It applies the tag when all edits are deletions.
			edits = append(edits, edit.ReplaceWithString(edit.Range{node.Pos(), node.Pos()}, ""))
			report.Report(pass, node, "could simplify selectors", report.Fixes(edit.Fix("Remove all embedded fields from selector", edits...)))
		}
	}
	code.Preorder(pass, fn, (*ast.SelectorExpr)(nil))
	return nil, nil
}
// timeEqualR is the replacement template lhs.Equal(rhs) used by
// CheckTimeEquality.
var timeEqualR = pattern.MustParse(`(CallExpr (SelectorExpr lhs (Ident "Equal")) rhs)`)
// CheckTimeEquality reports comparisons a == b in which both operands are
// time.Time values and suggests a.Equal(b) instead.
func CheckTimeEquality(pass *analysis.Pass) (interface{}, error) {
	// FIXME(dh): create proper suggested fix for renamed import
	isTime := func(e ast.Expr) bool {
		return code.IsOfType(pass, e, "time.Time")
	}

	code.Preorder(pass, func(node ast.Node) {
		cmp := node.(*ast.BinaryExpr)
		if cmp.Op != token.EQL || !isTime(cmp.X) || !isTime(cmp.Y) {
			return
		}

		state := pattern.State{"lhs": cmp.X, "rhs": cmp.Y}
		fix := edit.Fix("Use time.Time.Equal method", edit.ReplaceWithPattern(pass.Fset, node, timeEqualR, state))
		report.Report(pass, node, "probably want to use time.Time.Equal instead", report.Fixes(fix))
	}, (*ast.BinaryExpr)(nil))
	return nil, nil
}
// byteSlicePrintingQ matches calls to the fmt and log print functions
// that take an operand list directly (Print/Println-style, without a
// format string). For fmt.Fprint and fmt.Fprintln it skips the leading
// writer argument. The operands are bound to "args".
var byteSlicePrintingQ = pattern.MustParse(`
(Or
(CallExpr
(Function (Or
"fmt.Print"
"fmt.Println"
"fmt.Sprint"
"fmt.Sprintln"
"log.Fatal"
"log.Fatalln"
"log.Panic"
"log.Panicln"
"log.Print"
"log.Println"
"(*log.Logger).Fatal"
"(*log.Logger).Fatalln"
"(*log.Logger).Panic"
"(*log.Logger).Panicln"
"(*log.Logger).Print"
"(*log.Logger).Println")) args)
(CallExpr (Function (Or
"fmt.Fprint"
"fmt.Fprintln")) _:args))`)

// byteSlicePrintingR is the replacement template string(arg).
var byteSlicePrintingR = pattern.MustParse(`(CallExpr (Ident "string") [arg])`)
// CheckByteSlicePrinting reports operands of type []byte (by underlying
// type) passed to the fmt/log print functions matched by
// byteSlicePrintingQ. It suggests converting each of them with string(...).
// Operands whose type has a String() string method are left alone.
func CheckByteSlicePrinting(pass *analysis.Pass) (interface{}, error) {
	// implementsStringer reports whether ms contains a method with the
	// signature String() string.
	implementsStringer := func(ms *types.MethodSet) bool {
		sel := ms.Lookup(nil, "String")
		if sel == nil {
			return false
		}
		method, ok := sel.Obj().(*types.Func)
		if !ok {
			// should be unreachable
			return false
		}
		sig := method.Type().(*types.Signature)
		return sig.Params().Len() == 0 &&
			sig.Results().Len() == 1 &&
			typeutil.IsType(sig.Results().At(0).Type(), "string")
	}

	code.Preorder(pass, func(node ast.Node) {
		m, ok := code.Match(pass, byteSlicePrintingQ, node)
		if !ok {
			return
		}
		for _, arg := range m.State["args"].([]ast.Expr) {
			argT := pass.TypesInfo.TypeOf(arg)
			if !typeutil.IsType(argT.Underlying(), "[]byte") {
				continue
			}
			// don't convert arguments that implement fmt.Stringer
			if implementsStringer(types.NewMethodSet(argT)) {
				continue
			}
			fix := edit.Fix("Convert argument to string", edit.ReplaceWithPattern(pass.Fset, arg, byteSlicePrintingR, pattern.State{"arg": arg}))
			report.Report(pass, arg, "could convert argument to string", report.Fixes(fix))
		}
	}, (*ast.CallExpr)(nil))
	return nil, nil
}
var (
	// checkWriteBytesSprintfQ matches recv.Write([]byte(fmt.Sprint(args)))
	// and the Sprintf/Sprintln variants. It binds the receiver expression
	// to "recv", the fmt function to "fn" and its arguments to "args".
	checkWriteBytesSprintfQ = pattern.MustParse(`
(CallExpr
(SelectorExpr recv (Ident "Write"))
(CallExpr (ArrayType nil (Ident "byte"))
(CallExpr
fn@(Or
(Function "fmt.Sprint")
(Function "fmt.Sprintf")
(Function "fmt.Sprintln"))
args)
))`)

	// checkWriteStringSprintfQ matches recv.WriteString(fmt.Sprint(args))
	// and the Sprintf/Sprintln variants, with the same bindings as
	// checkWriteBytesSprintfQ.
	checkWriteStringSprintfQ = pattern.MustParse(`
(CallExpr
(SelectorExpr recv (Ident "WriteString"))
(CallExpr
fn@(Or
(Function "fmt.Sprint")
(Function "fmt.Sprintf")
(Function "fmt.Sprintln"))
args))`)

	// writerInterface is an interface with the single method
	// Write([]byte) (int, error), the method set of io.Writer.
	writerInterface = types.NewInterfaceType([]*types.Func{
		types.NewFunc(token.NoPos, nil, "Write", types.NewSignature(nil,
			types.NewTuple(types.NewVar(token.NoPos, nil, "", types.NewSlice(types.Typ[types.Byte]))),
			types.NewTuple(
				types.NewVar(token.NoPos, nil, "", types.Typ[types.Int]),
				types.NewVar(token.NoPos, nil, "", types.Universe.Lookup("error").Type()),
			),
			false,
		)),
	}, nil).Complete()

	// stringWriterInterface is an interface with the single method
	// WriteString(string) (int, error), the method set of io.StringWriter.
	stringWriterInterface = types.NewInterfaceType([]*types.Func{
		types.NewFunc(token.NoPos, nil, "WriteString", types.NewSignature(nil,
			types.NewTuple(types.NewVar(token.NoPos, nil, "", types.Universe.Lookup("string").Type())),
			types.NewTuple(
				types.NewVar(token.NoPos, nil, "", types.Typ[types.Int]),
				types.NewVar(token.NoPos, nil, "", types.Universe.Lookup("error").Type()),
			),
			false,
		)),
	}, nil).Complete()
)
// CheckWriteBytesSprintf reports recv.Write([]byte(fmt.S*(...))) and
// recv.WriteString(fmt.S*(...)) calls, where fmt.S* is fmt.Sprint,
// fmt.Sprintf or fmt.Sprintln. It suggests writing to recv directly with
// the corresponding fmt.F* function. The receiver must implement
// writerInterface; for WriteString it must also implement
// stringWriterInterface.
func CheckWriteBytesSprintf(pass *analysis.Pass) (interface{}, error) {
	// unparenCall returns e as a call expression, looking through
	// parentheses, or nil if e is not a call.
	unparenCall := func(e ast.Expr) *ast.CallExpr {
		call, _ := astutil.Unparen(e).(*ast.CallExpr)
		return call
	}

	// suggest reports node and offers to replace it with
	// fmt.F<suffix>(recv, args...), where fmt.S<suffix> is the matched
	// function. msgFormat receives the new and the old function name.
	//
	// inner is the matched fmt.S* call, or nil if it could not be located.
	// Its Ellipsis is copied to the replacement. Without it,
	// fmt.Sprintf(format, args...) would become fmt.Fprintf(w, format, args),
	// passing the slice as a single operand and changing the output.
	suggest := func(node ast.Node, state pattern.State, inner *ast.CallExpr, msgFormat string) {
		recv := state["recv"].(ast.Expr)
		name := state["fn"].(*types.Func).Name()
		newName := "F" + strings.TrimPrefix(name, "S")
		msg := fmt.Sprintf(msgFormat, newName, name)

		args := state["args"].([]ast.Expr)
		call := &ast.CallExpr{
			Fun: &ast.SelectorExpr{
				X:   ast.NewIdent("fmt"),
				Sel: ast.NewIdent(newName),
			},
			Args: append([]ast.Expr{recv}, args...),
		}
		if inner != nil {
			call.Ellipsis = inner.Ellipsis
		}
		fix := edit.Fix(msg, edit.ReplaceWithNode(pass.Fset, node, call))
		report.Report(pass, node, msg, report.Fixes(fix))
	}

	fn := func(node ast.Node) {
		outer := node.(*ast.CallExpr)
		if m, ok := code.Match(pass, checkWriteBytesSprintfQ, node); ok {
			recvT := pass.TypesInfo.TypeOf(m.State["recv"].(ast.Expr))
			if !types.Implements(recvT, writerInterface) {
				return
			}
			// node is recv.Write([]byte(fmt.S*(args))); locate the inner call.
			var inner *ast.CallExpr
			if len(outer.Args) == 1 {
				if conv := unparenCall(outer.Args[0]); conv != nil && len(conv.Args) == 1 {
					inner = unparenCall(conv.Args[0])
				}
			}
			suggest(node, m.State, inner, "Use fmt.%s(...) instead of Write([]byte(fmt.%s(...)))")
		} else if m, ok := code.Match(pass, checkWriteStringSprintfQ, node); ok {
			recvT := pass.TypesInfo.TypeOf(m.State["recv"].(ast.Expr))
			if !types.Implements(recvT, stringWriterInterface) {
				return
			}
			// The type needs to implement both StringWriter and Writer.
			// If it doesn't implement Writer, then we cannot pass it to fmt.Fprint.
			if !types.Implements(recvT, writerInterface) {
				return
			}
			// node is recv.WriteString(fmt.S*(args)); locate the inner call.
			var inner *ast.CallExpr
			if len(outer.Args) == 1 {
				inner = unparenCall(outer.Args[0])
			}
			suggest(node, m.State, inner, "Use fmt.%s(...) instead of WriteString(fmt.%s(...))")
		}
	}
	code.Preorder(pass, fn, (*ast.CallExpr)(nil))
	return nil, nil
}