Skip to content

Commit 702c1bf

Browse files
tchajedupamanyus
andauthored
proofgen: generate new definitions needed for initialization support (#65)
This PR adds generation of PkgInfo, PkgIsDefined, and PkgIsInitialized instances. It also shifts all the WpFuncCall and related instances to be based on is_pkg_defined to be systematic. At the same time, we change `pkg_name'` to be the name of the package itself. There is some subtlety here: if a package has an identifier of the same name (e.g., package `raft` has a type `raft`, or package `main` has a function `main`), then it must be referenced with a package qualifier. --------- Co-authored-by: Upamanyu Sharma <[email protected]>
1 parent 17b024e commit 702c1bf

File tree

21 files changed

+808
-607
lines changed

21 files changed

+808
-607
lines changed

glang/coq.go

+52
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,32 @@ func (d ConstDecl) DefName() (bool, string) {
899899
return true, d.Name
900900
}
901901

902+
type InstanceDecl struct {
903+
Type Expr
904+
// If not global, instance will be export
905+
Global bool
906+
Body Expr
907+
// Can be empty (instance gets an automatic name in Coq)
908+
Name string
909+
}
910+
911+
func (d InstanceDecl) CoqDecl() string {
912+
var pp buffer
913+
qualifier := "#[export]"
914+
if d.Global {
915+
qualifier = "#[global]"
916+
}
917+
pp.Add("%s Instance %s : %s :=",
918+
qualifier, d.Name, d.Type.Coq(false))
919+
pp.Indent(2)
920+
pp.Add("%s.", d.Body.Coq(false))
921+
return pp.Build()
922+
}
923+
924+
func (d InstanceDecl) DefName() (bool, string) {
925+
return true, d.Name
926+
}
927+
902928
type AxiomDecl struct {
903929
DeclName string
904930
Type Expr
@@ -989,6 +1015,32 @@ func (decls ImportDecls) PrintImports() string {
9891015
return strings.Join(ss, "\n")
9901016
}
9911017

1018+
type RecordField struct {
1019+
Name string
1020+
Value Expr
1021+
}
1022+
1023+
// RecordLiteral represents a Gallina record literal
1024+
type RecordLiteral struct {
1025+
Fields []RecordField
1026+
}
1027+
1028+
func (r RecordField) Coq(needs_paren bool) string {
1029+
return fmt.Sprintf("%s := %s", r.Name, r.Value.Coq(needs_paren))
1030+
}
1031+
1032+
func (r RecordLiteral) Coq(needs_paren bool) string {
1033+
var pp buffer
1034+
pp.AddLine("{|")
1035+
pp.Indent(2)
1036+
for _, field := range r.Fields {
1037+
pp.Add("%s;", field.Coq(false))
1038+
}
1039+
pp.Indent(-2)
1040+
pp.AddLine("|}")
1041+
return addParens(needs_paren, pp.Build())
1042+
}
1043+
9921044
// File represents a complete Coq file (a sequence of declarations).
9931045
type File struct {
9941046
ImportHeader string

go.work.sum

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
2121
golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys=
2222
golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE=
2323
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
24+
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
2425
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
2526
golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
2627
golang.org/x/telemetry v0.0.0-20240521205824-bda55230c457 h1:zf5N6UOrA487eEFacMePxjXAJctxKmyjKUsjA11Uzuk=

goose.go

+63-24
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ type Ctx struct {
3232
namesToTranslate map[string]bool
3333
info *types.Info
3434
pkgPath string
35+
36+
// XXX: Initially tried using `pkg.Name` as the Gallina identifier holding
37+
// the full package path, but that doesn't work in a `package main` with a `func main`.
38+
// In that case, as soon as `func main` is defined inside `Module main.`,
39+
// reference to simply `main` (which should be the go_string holding the
40+
// package path) end up referring to the function. So, this uses `filename +
41+
// "." + pkg.Name` to refer to the Gallina definition that holds a package's
42+
// full path as a go_string (e.g. in `globals_test`, instead of `func_call #main ...`,
43+
// this results in `func_call #globals_test.main ...`).
44+
pkgIdent string
3545
errorReporter
3646

3747
// XXX: this is so we can determine the expected return type when handling a
@@ -87,9 +97,12 @@ func getFfi(pkg *packages.Package) string {
8797

8898
// NewPkgCtx initializes a context based on a properly loaded package
8999
func NewPkgCtx(pkg *packages.Package, filter declfilter.DeclFilter) Ctx {
100+
ss := strings.Split(pkg.PkgPath, "/")
101+
pkgIdent := ss[len(ss)-1]
90102
return Ctx{
91103
info: pkg.TypesInfo,
92104
pkgPath: pkg.PkgPath,
105+
pkgIdent: pkgIdent + "." + pkg.Name,
93106
errorReporter: newErrorReporter(pkg.Fset),
94107
dep: newDepTracker(),
95108
importNames: make(map[string]struct{}),
@@ -646,26 +659,28 @@ func (ctx *Ctx) qualifiedName(obj types.Object) string {
646659
}
647660

648661
func (ctx *Ctx) getPkgAndName(obj types.Object) (pkg string, name string) {
662+
if obj.Pkg() == nil {
663+
ctx.unsupported(obj, "expected object to have package")
664+
}
649665
name = obj.Name()
650-
pkg = "pkg_name'"
651-
if obj.Pkg() == nil || ctx.pkgPath == obj.Pkg().Path() {
652-
return
666+
if obj.Pkg().Path() == ctx.pkgPath {
667+
pkg = ctx.pkgIdent
668+
} else {
669+
pkg = obj.Pkg().Name()
653670
}
654-
pkg = obj.Pkg().Name() + "." + pkg
655671
return
656672
}
657673

658674
func (ctx *Ctx) selectorExprAddr(e *ast.SelectorExpr) glang.Expr {
659675
selection := ctx.info.Selections[e]
660676
if selection == nil {
661-
pkg, ok := getIdent(e.X)
677+
pkgName, ok := getIdent(e.X)
662678
if !ok {
663679
ctx.unsupported(e, "expected package selector with idtent, got %T", e.X)
664680
}
665681
if _, ok := ctx.info.ObjectOf(e.Sel).(*types.Var); ok {
666-
ctx.dep.addDep("pkg_name'")
667682
return glang.NewCallExpr(glang.GallinaIdent("globals.get"),
668-
glang.StringVal{Value: glang.GallinaIdent(pkg + ".pkg_name'")},
683+
glang.StringVal{Value: glang.GallinaIdent(pkgName)},
669684
glang.StringVal{Value: glang.StringLiteral{Value: e.Sel.Name}},
670685
)
671686
} else {
@@ -858,7 +873,7 @@ func (ctx *Ctx) selectorExpr(e *ast.SelectorExpr) glang.Expr {
858873
ctx.info.TypeOf(e),
859874
glang.NewCallExpr(
860875
glang.GallinaIdent("func_call"),
861-
glang.StringVal{Value: glang.GallinaIdent(f.Pkg().Name() + ".pkg_name'")},
876+
glang.StringVal{Value: glang.GallinaIdent(f.Pkg().Name())},
862877
glang.StringVal{Value: glang.StringLiteral{Value: e.Sel.Name}},
863878
),
864879
)
@@ -1307,12 +1322,13 @@ func (ctx *Ctx) unaryExpr(e *ast.UnaryExpr, isSpecial bool) glang.Expr {
13071322
}
13081323

13091324
func (ctx *Ctx) function(s *ast.Ident) glang.Expr {
1310-
ctx.dep.addDep("pkg_name'")
1311-
13121325
typeArgs := ctx.info.Instances[s].TypeArgs
13131326
if typeArgs.Len() == 0 {
1327+
if ctx.info.ObjectOf(s).Pkg().Path() != ctx.pkgPath {
1328+
ctx.nope(s, "expected function ident to refer to local function")
1329+
}
13141330
return glang.NewCallExpr(glang.GallinaIdent("func_call"),
1315-
glang.StringVal{Value: glang.GallinaIdent("pkg_name'")},
1331+
glang.StringVal{Value: glang.GallinaIdent(ctx.pkgIdent)},
13161332
glang.StringVal{Value: glang.StringLiteral{Value: s.Name}},
13171333
)
13181334
}
@@ -1904,9 +1920,14 @@ func (ctx *Ctx) exprAddr(e ast.Expr) glang.Expr {
19041920
obj := ctx.info.ObjectOf(e)
19051921
if _, ok := obj.(*types.Var); ok {
19061922
if obj.Pkg().Scope() == obj.Parent() {
1907-
ctx.dep.addDep("pkg_name'")
1923+
pkgIdent := ""
1924+
if obj.Pkg().Path() == ctx.pkgPath {
1925+
pkgIdent = ctx.pkgIdent
1926+
} else {
1927+
pkgIdent = obj.Pkg().Name()
1928+
}
19081929
return glang.NewCallExpr(glang.GallinaIdent("globals.get"),
1909-
glang.StringVal{Value: glang.GallinaIdent("pkg_name'")},
1930+
glang.StringVal{Value: glang.GallinaIdent(pkgIdent)},
19101931
glang.StringVal{Value: glang.StringLiteral{Value: e.Name}},
19111932
)
19121933
} else {
@@ -2827,12 +2848,6 @@ func (ctx *Ctx) decl(d ast.Decl) []glang.Decl {
28272848

28282849
func (ctx *Ctx) initFunctions() []glang.Decl {
28292850
var decls = []glang.Decl{}
2830-
nameDecl := glang.ConstDecl{
2831-
Name: "pkg_name'",
2832-
Val: glang.GallinaString(ctx.pkgPath),
2833-
Type: glang.GallinaIdent("go_string"),
2834-
}
2835-
decls = append(decls, nameDecl)
28362851

28372852
ctx.dep.setCurrentName("initialize'")
28382853
initFunc := glang.FuncDecl{Name: "initialize'"}
@@ -2882,6 +2897,33 @@ func (ctx *Ctx) initFunctions() []glang.Decl {
28822897
}
28832898
decls = append(decls, msetsDecl)
28842899

2900+
var imports glang.ListExpr
2901+
for _, impName := range ctx.importNamesOrdered {
2902+
imports = append(imports, glang.GallinaIdent(fmt.Sprintf("%s", impName)))
2903+
}
2904+
infoRecord := glang.RecordLiteral{
2905+
Fields: []glang.RecordField{
2906+
{Name: "pkg_vars",
2907+
Value: glang.GallinaIdent("vars'")},
2908+
{Name: "pkg_functions",
2909+
Value: glang.GallinaIdent("functions'")},
2910+
{Name: "pkg_msets",
2911+
Value: glang.GallinaIdent("msets'")},
2912+
{Name: "pkg_imported_pkgs",
2913+
Value: imports},
2914+
},
2915+
}
2916+
2917+
infoInstanceDecl := glang.InstanceDecl{
2918+
Type: glang.NewCallExpr(glang.GallinaIdent("PkgInfo"),
2919+
glang.GallinaIdent(ctx.pkgIdent),
2920+
),
2921+
Global: true,
2922+
Body: infoRecord,
2923+
Name: "info'", // no name required
2924+
}
2925+
decls = append(decls, infoInstanceDecl)
2926+
28852927
var e glang.Expr
28862928

28872929
// add all init() function bodies
@@ -2919,7 +2961,7 @@ InitLoop:
29192961
e = glang.NewDoSeq(
29202962
glang.StoreStmt{
29212963
Dst: glang.NewCallExpr(glang.GallinaIdent("globals.get"),
2922-
glang.StringVal{Value: glang.GallinaIdent("pkg_name'")},
2964+
glang.StringVal{Value: glang.GallinaIdent(ctx.pkgIdent)},
29232965
glang.StringVal{Value: glang.StringLiteral{Value: init.Lhs[i-1].Name()}},
29242966
),
29252967
X: glang.IdentExpr(fmt.Sprintf("$r%d", i-1)),
@@ -2988,10 +3030,7 @@ InitLoop:
29883030

29893031
e = glang.NewCallExpr(glang.GallinaIdent("exception_do"), e)
29903032
e = glang.NewCallExpr(glang.GallinaIdent("globals.package_init"),
2991-
glang.GallinaIdent("pkg_name'"),
2992-
glang.GallinaIdent("vars'"),
2993-
glang.GallinaIdent("functions'"),
2994-
glang.GallinaIdent("msets'"),
3033+
glang.GallinaIdent(ctx.pkgIdent),
29953034
glang.FuncLit{Args: nil, Body: e},
29963035
)
29973036

interface.go

+1
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ func translatePackage(pkg *packages.Package, configContents []byte) (glang.File,
240240

241241
func (ctx *Ctx) ffiHeaderFooter(pkg *packages.Package) (header string, footer string) {
242242
ffi := getFfi(pkg)
243+
header += fmt.Sprintf("Definition %s : go_string := \"%s\".\n\n", pkg.Name, pkg.PkgPath)
243244
if ffi == "none" {
244245
header += fmt.Sprintf("Module %s.\n", pkg.Name)
245246
header += "Section code.\n" +

proofgen/imports.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func (tr *importsTranslator) Decl(d ast.Decl) {
4242
}
4343
}
4444

45-
func translateImports(w io.Writer, pkg *packages.Package, usingFfi bool, ffi string, filter declfilter.DeclFilter) {
45+
func translateImports(w io.Writer, pkg *packages.Package, usingFfi bool, ffi string, filter declfilter.DeclFilter) *importsTranslator {
4646
tr := &importsTranslator{
4747
importsSet: make(map[string]struct{}),
4848
filter: filter,
@@ -55,4 +55,5 @@ func translateImports(w io.Writer, pkg *packages.Package, usingFfi bool, ffi str
5555
for _, imp := range tr.importsList {
5656
fmt.Fprintf(w, "Require Export New.generatedproof.%s.\n", imp)
5757
}
58+
return tr
5859
}

proofgen/names.go

+12-9
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,13 @@ Class GlobalAddrs :=
9696
}
9797
fmt.Fprintln(w, "\n ].")
9898

99-
// emit `is_defined`
99+
// emit `PkgIsDefined instance`
100100
fmt.Fprintf(w, `
101-
Definition is_defined := is_global_definitions %s.pkg_name' var_addrs %s.functions' %s.msets'.
102-
`, pkg.Name, pkg.Name, pkg.Name)
101+
Global Instance is_pkg_defined_instance : IsPkgDefined %s :=
102+
{|
103+
is_pkg_defined := is_global_definitions %s var_addrs;
104+
|}.
105+
`, pkg.Name, pkg.Name)
103106

104107
// emit `own_allocated`
105108
fmt.Fprint(w, "\nDefinition own_allocated `{!GlobalAddrs} : iProp Σ :=\n")
@@ -121,7 +124,7 @@ Definition is_defined := is_global_definitions %s.pkg_name' var_addrs %s.functio
121124
// emit instances for global.get
122125
for _, varName := range tr.varNames {
123126
fmt.Fprintf(w, "\nGlobal Instance wp_globals_get_%s : \n", varName)
124-
fmt.Fprintf(w, " WpGlobalsGet %s.pkg_name' \"%s\" %s is_defined.\n", pkg.Name, varName, varName)
127+
fmt.Fprintf(w, " WpGlobalsGet %s \"%s\" %s (is_pkg_defined %s).\n", pkg.Name, varName, varName, pkg.Name)
125128
fmt.Fprintf(w, "Proof. apply wp_globals_get'. reflexivity. Qed.\n")
126129
}
127130

@@ -131,7 +134,7 @@ Definition is_defined := is_global_definitions %s.pkg_name' var_addrs %s.functio
131134
continue
132135
}
133136
fmt.Fprintf(w, "\nGlobal Instance wp_func_call_%s :\n", funcName)
134-
fmt.Fprintf(w, " WpFuncCall %s.pkg_name' \"%s\" _ is_defined :=\n", pkg.Name, funcName)
137+
fmt.Fprintf(w, " WpFuncCall %s \"%s\" _ (is_pkg_defined %s) :=\n", pkg.Name, funcName, pkg.Name)
135138
fmt.Fprintf(w, " ltac:(apply wp_func_call'; reflexivity).\n")
136139
}
137140

@@ -147,8 +150,8 @@ Definition is_defined := is_global_definitions %s.pkg_name' var_addrs %s.functio
147150
}
148151

149152
fmt.Fprintf(w, "\nGlobal Instance wp_method_call_%s_%s :\n", typeName, methodName)
150-
fmt.Fprintf(w, " WpMethodCall %s.pkg_name' \"%s\" \"%s\" _ is_defined :=\n",
151-
pkg.Name, typeName, methodName)
153+
fmt.Fprintf(w, " WpMethodCall %s \"%s\" \"%s\" _ (is_pkg_defined %s) :=\n",
154+
pkg.Name, typeName, methodName, pkg.Name)
152155
fmt.Fprintf(w, " ltac:(apply wp_method_call'; reflexivity).\n")
153156
// XXX: by using an ltac expression to generate the instance, we can
154157
// leave an evar for the method val, avoiding the need to write out
@@ -164,8 +167,8 @@ Definition is_defined := is_global_definitions %s.pkg_name' var_addrs %s.functio
164167
}
165168

166169
fmt.Fprintf(w, "\nGlobal Instance wp_method_call_%s_%s :\n", typeName, methodName)
167-
fmt.Fprintf(w, " WpMethodCall %s.pkg_name' \"%s\" \"%s\" _ is_defined :=\n",
168-
pkg.Name, typeName, methodName)
170+
fmt.Fprintf(w, " WpMethodCall %s \"%s\" \"%s\" _ (is_pkg_defined %s) :=\n",
171+
pkg.Name, typeName, methodName, pkg.Name)
169172
fmt.Fprintf(w, " ltac:(apply wp_method_call'; reflexivity).\n")
170173
}
171174
}

proofgen/proofgen.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
func Package(w io.Writer, pkg *packages.Package, usingFfi bool, ffi string, filter declfilter.DeclFilter) {
14-
fmt.Fprintf(w, "(* autogenerated by goose proofgen (types); do not modify *)\n")
14+
fmt.Fprintf(w, "(* autogenerated by goose proofgen; do not modify *)\n")
1515

1616
if usingFfi {
1717
fmt.Fprintf(w, "Require Export New.proof.%s_prelude.\n", ffi)
@@ -26,8 +26,8 @@ func Package(w io.Writer, pkg *packages.Package, usingFfi bool, ffi string, filt
2626

2727
translateImports(w, pkg, usingFfi, ffi, filter)
2828

29-
fmt.Fprintf(w, "Require Export New.code.%s.\n", coqPath)
3029
fmt.Fprintf(w, "Require Export New.golang.theory.\n\n")
30+
fmt.Fprintf(w, "Require Export New.code.%s.\n", coqPath)
3131

3232
fmt.Fprintf(w, "Module %s.\n", pkg.Name)
3333

proofgen/types.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ Context `+"`"+`{!ffi_model, !ffi_semantics _ _, !ffi_interp _, !heapGS Σ}.
241241
fmt.Fprintf(w, "%s\n \"%s\" ::= #%s", sep, getFieldName(i), toCoqName(getFieldName(i)))
242242
sep = ";"
243243
}
244-
fmt.Fprint(w, "\n ]))%%V\n #(")
244+
fmt.Fprintf(w, "%s", "\n ]))"+`%struct`+"\n #(")
245245
fmt.Fprintf(w, "%s.mk", name)
246246
for i := 0; i < s.NumFields(); i++ {
247247
fmt.Fprintf(w, " %s", toCoqName(getFieldName(i)))
@@ -341,7 +341,7 @@ func translateTypes(w io.Writer, pkg *packages.Package, usingFfi bool, ffi strin
341341
for _, depName := range tr.deps[n] {
342342
printDefAndDeps(depName)
343343
}
344-
fmt.Fprintf(w, tr.defs[n])
344+
fmt.Fprint(w, tr.defs[n])
345345
printed[n] = true
346346
}
347347
for _, d := range tr.defNames {

0 commit comments

Comments
 (0)