Skip to content

Commit

Permalink
Support generic struct types
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronjwood committed Oct 18, 2022
1 parent c585660 commit fd46275
Show file tree
Hide file tree
Showing 9 changed files with 246 additions and 53 deletions.
9 changes: 4 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module github.com/jfeliu007/goplantuml

go 1.17
go 1.18

require (
github.com/spf13/afero v1.8.2
golang.org/x/text v0.3.7 // indirect
)
require github.com/spf13/afero v1.8.2

require golang.org/x/text v0.3.7 // indirect
46 changes: 30 additions & 16 deletions parser/class_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ call the Render() function and this will return a string with the class diagram.
See github.com/jfeliu007/goplantuml/cmd/goplantuml/main.go for a command that uses this functions and outputs the text to
the console.
*/
package parser

Expand Down Expand Up @@ -267,7 +266,6 @@ func (p *ClassParser) parseFileDeclarations(node ast.Decl) {
}

func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {

if decl.Recv != nil {
if decl.Recv.List == nil {
return
Expand Down Expand Up @@ -296,24 +294,30 @@ func (p *ClassParser) handleFuncDecl(decl *ast.FuncDecl) {
}
}

func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType) {
func handleGenDecStructType(p *ClassParser, typeName string, c *ast.StructType, typeParams *ast.FieldList) {
for _, f := range c.Fields.List {
p.getOrCreateStruct(typeName).AddField(f, p.allImports)
}

if typeParams == nil {
return
}

for _, tp := range typeParams.List {
p.getOrCreateStruct(typeName).AddTypeParam(tp)
}
}

func handleGenDecInterfaceType(p *ClassParser, typeName string, c *ast.InterfaceType) {
for _, f := range c.Methods.List {
switch t := f.Type.(type) {
case *ast.FuncType:
p.getOrCreateStruct(typeName).AddMethod(f, p.allImports)
break
case *ast.Ident:
f, _ := getFieldType(t, p.allImports)
st := p.getOrCreateStruct(typeName)
f = replacePackageConstant(f, st.PackageName)
st.AddToComposition(f)
break
}
}
}
Expand All @@ -338,7 +342,7 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
switch c := v.Type.(type) {
case *ast.StructType:
declarationType = "class"
handleGenDecStructType(p, typeName, c)
handleGenDecStructType(p, typeName, c, v.TypeParams)
case *ast.InterfaceType:
declarationType = "interface"
handleGenDecInterfaceType(p, typeName, c)
Expand Down Expand Up @@ -379,7 +383,6 @@ func (p *ClassParser) processSpec(spec ast.Spec) {
p.allRenamedStructs[pack[0]][renamedClass] = pack[1]
}
}
return
}

// If this element is an array or a pointer, this function will return the type that is closer to these
Expand Down Expand Up @@ -465,7 +468,7 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
str.WriteLineWithDepth(2, aliasComplexNameComment)
str.WriteLineWithDepth(1, "}")
}
str.WriteLineWithDepth(0, fmt.Sprintf(`}`))
str.WriteLineWithDepth(0, `}`)
if p.renderingOptions.Compositions {
str.WriteLineWithDepth(0, composition.String())
}
Expand All @@ -479,7 +482,6 @@ func (p *ClassParser) renderStructures(pack string, structures map[string]*Struc
}

func (p *ClassParser) renderAliases(str *LineStringBuilder) {

aliasString := ""
if p.renderingOptions.ConnectionLabels {
aliasString = aliasOf
Expand All @@ -505,7 +507,6 @@ func (p *ClassParser) renderAliases(str *LineStringBuilder) {
}

func (p *ClassParser) renderStructure(structure *Struct, pack string, name string, str *LineStringBuilder, composition *LineStringBuilder, extends *LineStringBuilder, aggregations *LineStringBuilder) {

privateFields := &LineStringBuilder{}
publicFields := &LineStringBuilder{}
privateMethods := &LineStringBuilder{}
Expand All @@ -518,9 +519,24 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
case "alias":
sType = "<< (T, #FF7700) >> "
renderStructureType = "class"
}

types := ""
if structure.Generics.exists() {
types = "<"
for t := range structure.Generics.Types {
types += fmt.Sprintf("%s, ", t)
}
types = strings.TrimSuffix(types, ", ")
types += " constrains "
for _, n := range structure.Generics.Names {
types += fmt.Sprintf("%s, ", n)
}
types = strings.TrimSuffix(types, ", ")
types += ">"
}
str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s %s {`, renderStructureType, name, sType))

str.WriteLineWithDepth(1, fmt.Sprintf(`%s %s%s %s {`, renderStructureType, name, types, sType))
p.renderStructFields(structure, privateFields, publicFields)
p.renderStructMethods(structure, privateMethods, publicMethods)
p.renderCompositions(structure, name, composition)
Expand All @@ -538,7 +554,7 @@ func (p *ClassParser) renderStructure(structure *Struct, pack string, name strin
if publicMethods.Len() > 0 {
str.WriteLineWithDepth(0, publicMethods.String())
}
str.WriteLineWithDepth(1, fmt.Sprintf(`}`))
str.WriteLineWithDepth(1, `}`)
}

func (p *ClassParser) renderCompositions(structure *Struct, name string, composition *LineStringBuilder) {
Expand All @@ -562,7 +578,6 @@ func (p *ClassParser) renderCompositions(structure *Struct, name string, composi
}

func (p *ClassParser) renderAggregations(structure *Struct, name string, aggregations *LineStringBuilder) {

aggregationMap := structure.Aggregations
if p.renderingOptions.AggregatePrivateMembers {
p.updatePrivateAggregations(structure, aggregationMap)
Expand All @@ -571,7 +586,6 @@ func (p *ClassParser) renderAggregations(structure *Struct, name string, aggrega
}

func (p *ClassParser) updatePrivateAggregations(structure *Struct, aggregationsMap map[string]struct{}) {

for agg := range structure.PrivateAggregations {
aggregationsMap[agg] = struct{}{}
}
Expand Down Expand Up @@ -600,13 +614,13 @@ func (p *ClassParser) renderAggregationMap(aggregationMap map[string]struct{}, s
}

func (p *ClassParser) getPackageName(t string, st *Struct) string {

packageName := st.PackageName
if isPrimitiveString(t) {
packageName = builtinPackageName
}
return packageName
}

func (p *ClassParser) renderExtends(structure *Struct, name string, extends *LineStringBuilder) {

orderedExtends := []string{}
Expand All @@ -628,7 +642,6 @@ func (p *ClassParser) renderExtends(structure *Struct, name string, extends *Lin
}

func (p *ClassParser) renderStructMethods(structure *Struct, privateMethods *LineStringBuilder, publicMethods *LineStringBuilder) {

for _, method := range structure.Functions {
accessModifier := "+"
if unicode.IsLower(rune(method.Name[0])) {
Expand Down Expand Up @@ -685,6 +698,7 @@ func (p *ClassParser) getOrCreateStruct(name string) *Struct {
Functions: make([]*Function, 0),
Fields: make([]*Field, 0),
Type: "",
Generics: NewGeneric(),
Composition: make(map[string]struct{}, 0),
Extends: make(map[string]struct{}, 0),
Aggregations: make(map[string]struct{}, 0),
Expand Down
9 changes: 5 additions & 4 deletions parser/class_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package parser

import (
"go/ast"
"io/ioutil"
"os"
"reflect"
"testing"
)
Expand Down Expand Up @@ -94,6 +94,7 @@ func TestGetOrCreateStruct(t *testing.T) {
Functions: make([]*Function, 0),
Fields: make([]*Field, 0),
Type: "",
Generics: NewGeneric(),
Composition: make(map[string]struct{}, 0),
Extends: make(map[string]struct{}, 0),
Aggregations: make(map[string]struct{}, 0),
Expand Down Expand Up @@ -181,7 +182,6 @@ func TestRenderStructFields(t *testing.T) {
}

func TestRenderStructures(t *testing.T) {

structMap := map[string]*Struct{
"MainClass": getTestStruct(),
}
Expand Down Expand Up @@ -296,6 +296,7 @@ func getTestStruct() *Struct {
ReturnValues: []string{"int"},
},
},
Generics: NewGeneric(),
}
}

Expand Down Expand Up @@ -563,7 +564,7 @@ func TestRender(t *testing.T) {
})

resultRender := parser.Render()
result, err := ioutil.ReadFile("../testingsupport/testingsupport.puml")
result, err := os.ReadFile("../testingsupport/testingsupport.puml")
if err != nil {
t.Errorf("TestRender: expected no errors reading testing file, got %s", err.Error())
}
Expand Down Expand Up @@ -592,7 +593,7 @@ func TestMultipleFolders(t *testing.T) {
}

resultRender := parser.Render()
result, err := ioutil.ReadFile("../testingsupport/subfolder1-2.puml")
result, err := os.ReadFile("../testingsupport/subfolder1-2.puml")
if err != nil {
t.Errorf("TestMultipleFolders: expected no errors reading testing file, got %s", err.Error())
}
Expand Down
18 changes: 4 additions & 14 deletions parser/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (

const packageConstant = "{packageName}"

//Field can hold the name and type of any field
// Field can hold the name and type of any field
type Field struct {
Name string
Type string
FullType string
}

//Returns a string representation of the given expression if it was recognized.
//Refer to the implementation to see the different string representations.
// Returns a string representation of the given expression if it was recognized.
// Refer to the implementation to see the different string representations.
func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
switch v := exp.(type) {
case *ast.Ident:
Expand Down Expand Up @@ -45,7 +45,6 @@ func getFieldType(exp ast.Expr, aliases map[string]string) (string, []string) {
}

func getIdent(v *ast.Ident, aliases map[string]string) (string, []string) {

if isPrimitive(v) {
return v.Name, []string{}
}
Expand All @@ -59,7 +58,6 @@ func getArrayType(v *ast.ArrayType, aliases map[string]string) (string, []string
}

func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []string) {

packageName := v.X.(*ast.Ident).Name
if realPackageName, ok := aliases[packageName]; ok {
packageName = realPackageName
Expand All @@ -69,26 +67,22 @@ func getSelectorExp(v *ast.SelectorExpr, aliases map[string]string) (string, []s
}

func getMapType(v *ast.MapType, aliases map[string]string) (string, []string) {

t1, f1 := getFieldType(v.Key, aliases)
t2, f2 := getFieldType(v.Value, aliases)
return fmt.Sprintf("<font color=blue>map</font>[%s]%s", t1, t2), append(f1, f2...)
}

func getStarExp(v *ast.StarExpr, aliases map[string]string) (string, []string) {

t, f := getFieldType(v.X, aliases)
return fmt.Sprintf("*%s", t), f
}

func getChanType(v *ast.ChanType, aliases map[string]string) (string, []string) {

t, f := getFieldType(v.Value, aliases)
return fmt.Sprintf("<font color=blue>chan</font> %s", t), f
}

func getStructType(v *ast.StructType, aliases map[string]string) (string, []string) {

fieldList := make([]string, 0)
for _, field := range v.Fields.List {
t, _ := getFieldType(field.Type, aliases)
Expand All @@ -98,7 +92,6 @@ func getStructType(v *ast.StructType, aliases map[string]string) (string, []stri
}

func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string, []string) {

methods := make([]string, 0)
for _, field := range v.Methods.List {
methodName := ""
Expand All @@ -112,17 +105,14 @@ func getInterfaceType(v *ast.InterfaceType, aliases map[string]string) (string,
}

func getFuncType(v *ast.FuncType, aliases map[string]string) (string, []string) {

function := getFunction(v, "", aliases, "")
params := make([]string, 0)
for _, pa := range function.Parameters {
params = append(params, pa.Type)
}
returns := ""
returnList := make([]string, 0)
for _, re := range function.ReturnValues {
returnList = append(returnList, re)
}
returnList = append(returnList, function.ReturnValues...)
if len(returnList) > 1 {
returns = fmt.Sprintf("(%s)", strings.Join(returnList, ", "))
} else {
Expand Down
Loading

0 comments on commit fd46275

Please sign in to comment.