Skip to content

Commit

Permalink
Allow also directed channels in callback stubs
Browse files Browse the repository at this point in the history
Fixes previous commit which accidentally only allowed undirected channels.

#84
  • Loading branch information
petergtz committed May 13, 2019
1 parent 5d8ef84 commit e86ce18
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 18 deletions.
19 changes: 19 additions & 0 deletions dsl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ var (
BeTrue = gomega.BeTrue
ConsistOf = gomega.ConsistOf
ContainSubstring = gomega.ContainSubstring
MatchError = gomega.MatchError
Equal = gomega.Equal
Expect = gomega.Expect
HaveLen = gomega.HaveLen
Expand Down Expand Up @@ -894,6 +895,24 @@ var _ = Describe("MockDisplay", func() {
})
display.ChanReturnValues()
})

It("allows to return directed channels from callbacks", func() {
When(display.ChanReturnValues()).Then(func([]pegomock.Param) pegomock.ReturnValues {
return []ReturnValue{make(<-chan string), make(chan<- error)}
})
display.ChanReturnValues()
})

It("does not allow to return directed channels from callbacks with wrong direction", func() {
When(display.ChanReturnValues()).Then(func([]pegomock.Param) pegomock.ReturnValues {
return []ReturnValue{make(chan<- string), make(chan<- error)}
})

Expect(func() { display.ChanReturnValues() }).To(PanicWithMessageTo(MatchError(
"interface conversion: pegomock.ReturnValue is chan<- string, not <-chan string",
)))

})
})

Context("using send-/receive-only channels", func() {
Expand Down
45 changes: 27 additions & 18 deletions mockgen/mockgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (g *generator) generateMockFor(iface *model.Interface, mockTypeName, selfPa
g.generateVerifierType(mockTypeName)
for _, method := range iface.Methods {
ongoingVerificationTypeName := fmt.Sprintf("%v_%v_OngoingVerification", mockTypeName, method.Name)
args, argNames, argTypes, _, _ := argDataFor(method, g.packageMap, selfPackage)
args, argNames, argTypes, _ := argDataFor(method, g.packageMap, selfPackage)
g.generateVerifierMethod(mockTypeName, method, selfPackage, ongoingVerificationTypeName, args, argNames)
g.generateOngoingVerificationType(mockTypeName, ongoingVerificationTypeName)
g.generateOngoingVerificationGetCapturedArguments(ongoingVerificationTypeName, argNames, argTypes)
Expand Down Expand Up @@ -185,15 +185,15 @@ func (g *generator) generateMockType(mockTypeName string) {

// If non-empty, pkgOverride is the package in which unqualified types reside.
func (g *generator) generateMockMethod(mockType string, method *model.Method, pkgOverride string) *generator {
args, argNames, _, signatureReturnTypes, returnTypes := argDataFor(method, g.packageMap, pkgOverride)
g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(signatureReturnTypes))
args, argNames, _, returnTypes := argDataFor(method, g.packageMap, pkgOverride)
g.p("func (mock *%v) %v(%v) (%v) {", mockType, method.Name, join(args), join(stringSliceFrom(returnTypes, g.packageMap, pkgOverride)))
g.p("if mock == nil {").
p(" panic(\"mock must not be nil. Use myMock := New%v().\")", mockType).
p("}")
g.GenerateParamsDeclaration(argNames, method.Variadic != nil)
reflectReturnTypes := make([]string, len(returnTypes))
for i, returnType := range returnTypes {
reflectReturnTypes[i] = fmt.Sprintf("reflect.TypeOf((*%v)(nil)).Elem()", returnType)
reflectReturnTypes[i] = fmt.Sprintf("reflect.TypeOf((*%v)(nil)).Elem()", returnType.String(g.packageMap, pkgOverride))
}
resultAssignment := ""
if len(method.Out) > 0 {
Expand All @@ -204,13 +204,23 @@ func (g *generator) generateMockMethod(mockType string, method *model.Method, pk
if len(method.Out) > 0 {
// TODO: translate LastInvocation into a Matcher so it can be used as key for Stubbings
for i, returnType := range returnTypes {
g.p("var ret%v %v", i, returnType)
g.p("var ret%v %v", i, returnType.String(g.packageMap, pkgOverride))
}
g.p("if len(result) != 0 {")
returnValues := make([]string, len(returnTypes))
for i, returnType := range returnTypes {
g.p("if result[%v] != nil {", i)
g.p("ret%v = result[%v].(%v)", i, i, returnType)
if chanType, isChanType := returnType.(*model.ChanType); isChanType && chanType.Dir != 0 {
undirectedChanType := *chanType
undirectedChanType.Dir = 0
g.p("var ok bool").
p(" ret%v, ok = result[%v].(%v)", i, i, undirectedChanType.String(g.packageMap, pkgOverride))
g.p("if !ok{").
p("ret%v = result[%v].(%v)", i, i, chanType.String(g.packageMap, pkgOverride)).
p("}")
} else {
g.p("ret%v = result[%v].(%v)", i, i, returnType.String(g.packageMap, pkgOverride))
}
g.p("}")
returnValues[i] = fmt.Sprintf("ret%v", i)
}
Expand Down Expand Up @@ -353,8 +363,7 @@ func argDataFor(method *model.Method, packageMap map[string]string, pkgOverride
args []string,
argNames []string,
argTypes []string,
signatureReturnTypes []string,
returnTypes []string,
returnTypes []model.Type,
) {
args = make([]string, len(method.In))
argNames = make([]string, len(method.In))
Expand All @@ -379,21 +388,21 @@ func argDataFor(method *model.Method, packageMap map[string]string, pkgOverride
argNames = append(argNames, argName)
argTypes = append(argTypes, "[]"+argType)
}
signatureReturnTypes = make([]string, len(method.Out))
returnTypes = make([]string, len(method.Out))
returnTypes = make([]model.Type, len(method.Out))
for i, ret := range method.Out {
if chanType, isChanType := ret.Type.(*model.ChanType); isChanType {
chanTypeNoDir := *chanType
chanTypeNoDir.Dir = 0
returnTypes[i] = chanTypeNoDir.String(packageMap, pkgOverride)
} else {
returnTypes[i] = ret.Type.String(packageMap, pkgOverride)
}
signatureReturnTypes[i] = ret.Type.String(packageMap, pkgOverride)
returnTypes[i] = ret.Type
}
return
}

func stringSliceFrom(types []model.Type, packageMap map[string]string, pkgOverride string) []string {
result := make([]string, len(types))
for i, t := range types {
result[i] = t.String(packageMap, pkgOverride)
}
return result
}

func addTypesFromMethodParamsTo(typesSet map[string]string, params []*model.Parameter, packageMap map[string]string) {
for _, param := range params {
switch typedType := param.Type.(type) {
Expand Down

0 comments on commit e86ce18

Please sign in to comment.