From 79544cc282a224e1fca0b813791c92954919e3a0 Mon Sep 17 00:00:00 2001 From: Will Bradley Date: Mon, 13 Jul 2020 22:48:05 -0600 Subject: [PATCH] wip --- src/ast.cpp | 8 ++---- src/data_ctors_map.cpp | 1 - src/gen.cpp | 4 +++ src/infer.cpp | 18 ++++++------- src/main.cpp | 2 +- src/match.cpp | 13 ++++----- src/parser.cpp | 14 +++++++--- src/patterns.cpp | 21 ++++++++++----- src/translate.cpp | 13 ++++++++- src/types.cpp | 60 +++++++++++++++++------------------------- src/types.h | 10 +++---- src/utils.h | 4 +++ 12 files changed, 93 insertions(+), 75 deletions(-) diff --git a/src/ast.cpp b/src/ast.cpp index 3a087eb1..986e72f8 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -252,7 +252,7 @@ Location Tuple::get_location() const { std::ostream &Tuple::render(std::ostream &os, int parent_precedence) const { os << "("; for (auto dim : dims) { - dim->render(os, 0); + dim->render(os, 10); os << ","; } return os << ")"; @@ -288,11 +288,7 @@ Location FFI::get_location() const { } std::ostream &FFI::render(std::ostream &os, int parent_precedence) const { - os << C_WARN << "ffi(" << escape_json_quotes(id.name); - for (auto expr : exprs) { - os << ", " << expr->str(); - } - return os << ")"; + return os << C_WARN << "ffi " << id.str() << "(" << join_str(exprs) << ")"; } Location Conditional::get_location() const { diff --git a/src/data_ctors_map.cpp b/src/data_ctors_map.cpp index 7ecd7ee5..80bc251c 100644 --- a/src/data_ctors_map.cpp +++ b/src/data_ctors_map.cpp @@ -49,7 +49,6 @@ types::Ref get_data_ctor_type(const DataCtorsMap &data_ctors_map, std::map get_data_ctors_types( const DataCtorsMap &data_ctors_map, types::Ref type) { - std::cerr << "unfolding " << type->str() << std::endl; types::Refs type_terms; unfold_ops_lassoc(type, type_terms); assert(type_terms.size() != 0); diff --git a/src/gen.cpp b/src/gen.cpp index 9f779e58..a5f4b3f5 100644 --- a/src/gen.cpp +++ b/src/gen.cpp @@ -184,6 +184,10 @@ void get_free_vars(const ast::Expr *expr, } else if (auto as = dcast(expr)) { get_free_vars(as->expr, typing, globals, locals, free_vars); } else if (dcast(expr)) { + } else if (auto ffi = dcast(expr)) { + for (auto expr : ffi->exprs) { + get_free_vars(expr, typing, globals, locals, free_vars); + } } else if (auto builtin = dcast(expr)) { for (auto expr : builtin->exprs) { get_free_vars(expr, typing, globals, locals, free_vars); diff --git a/src/infer.cpp b/src/infer.cpp index 103061e7..876e6509 100644 --- a/src/infer.cpp +++ b/src/infer.cpp @@ -67,14 +67,14 @@ types::Ref infer_core(const Expr *expr, "return type does not match type annotation :: %s", lambda->return_type->str().c_str())); } - return type_arrow(tv, local_return_type); + return type_arrow(type_params({tv}), local_return_type); } else if (auto application = dcast(expr)) { auto t1 = infer(application->a, data_ctors_map, return_type, scheme_resolver, tracked_types, constraints, instance_requirements); - auto t2 = infer(application->b, data_ctors_map, return_type, - scheme_resolver, tracked_types, constraints, - instance_requirements); + auto t2 = type_params( + {infer(application->b, data_ctors_map, return_type, scheme_resolver, + tracked_types, constraints, instance_requirements)}); auto tv = type_variable(expr->get_location()); append_to_constraints( constraints, t1, type_arrow(application->get_location(), t2, tv), @@ -131,9 +131,9 @@ types::Ref infer_core(const Expr *expr, make_context(defer->get_location(), "defer must call nullary function")); - auto t2 = infer(defer->application->b, data_ctors_map, return_type, - scheme_resolver, tracked_types, constraints, - instance_requirements); + auto t2 = type_params({infer(defer->application->b, data_ctors_map, + return_type, scheme_resolver, tracked_types, + constraints, instance_requirements)}); append_to_constraints( constraints, t1, type_arrow(defer->application->get_location(), t2, @@ -374,8 +374,8 @@ types::Ref CtorPredicate::tracking_infer( ctor_type->str().c_str())); types::Refs outer_ctor_terms = unfold_arrows(ctor_type); - types::Refs ctor_terms = get_ctor_terms(get_location(), ctor_name.str(), - outer_ctor_terms, params.size()); + types::Refs ctor_terms = get_ctor_param_terms( + get_location(), ctor_name.str(), outer_ctor_terms, params.size()); types::Ref result_type; for (size_t i = 0; i < params.size(); ++i) { diff --git a/src/main.cpp b/src/main.cpp index 2aec31c7..067de014 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -417,7 +417,7 @@ CheckedDefinitionsByName check_decls(std::string entry_point_name, append_to_constraints( constraints, ty, type_arrow(INTERNAL_LOC(), - type_unit(INTERNAL_LOC()), + type_params({type_unit(INTERNAL_LOC())}), type_unit(INTERNAL_LOC())), make_context(INTERNAL_LOC(), "main function must have signature fn () ()")); diff --git a/src/match.cpp b/src/match.cpp index d0bbba80..a25a7f73 100644 --- a/src/match.cpp +++ b/src/match.cpp @@ -399,8 +399,8 @@ Pattern::ref from_type(Location location, return allFloats; } else if (unify(type, type_ptr(type_variable(location))).result) { return all_of(location, {}, data_ctors_map, type); - } else if (unify(type, - type_arrow(type_variable(location), type_variable(location))) + } else if (unify(type, type_arrow(type_params({type_variable(location)}), + type_variable(location))) .result) { return all_of(location, {}, data_ctors_map, type); } else { @@ -410,7 +410,7 @@ Pattern::ref from_type(Location location, for (auto pair : ctors_types) { auto &ctor_name = pair.first; - auto ctor_terms = get_ctor_terms(unfold_arrows(pair.second)); + auto ctor_terms = get_ctor_param_terms(unfold_arrows(pair.second)); std::vector args; if (ctor_terms.size() != 0) { @@ -732,7 +732,7 @@ Pattern::ref CtorPredicate::get_pattern( auto outer_ctor_terms = unfold_arrows( get_data_ctor_type(data_ctors_map, type, ctor_name)); - types::Refs ctor_terms = get_ctor_terms(get_location(), ctor_name.str(), + types::Refs ctor_terms = get_ctor_param_terms(get_location(), ctor_name.str(), outer_ctor_terms, params.size()); std::vector args; @@ -740,8 +740,9 @@ Pattern::ref CtorPredicate::get_pattern( args.push_back(params[i]->get_pattern(ctor_terms[i], data_ctors_map)); } - std::cerr << "Pattern for CtorPredicate for " << type->str() << " has args " - << join_str(args) << std::endl; + debug_above(4, log("pattern for CtorPredicate for %s has args %s", + type->str().c_str(), join_str(args).c_str())); + /* found the ctor we're matching on */ return std::make_shared( location, CtorPatternValue{type->repr(), ctor_name.name, args}); diff --git a/src/parser.cpp b/src/parser.cpp index aa091965..b0be49f0 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -1740,6 +1740,7 @@ const Predicate *parse_ctor_predicate(ParseState &ps, Identifier ctor_name = ps.identifier_and_advance(); std::vector params; + auto location = ps.token.location; if (ps.token.tk == tk_lparen) { ps.advance(); while (ps.token.tk != tk_rparen) { @@ -1751,8 +1752,15 @@ const Predicate *parse_ctor_predicate(ParseState &ps, } chomp_token(tk_rparen); } - return new CtorPredicate(ctor_name.location, params, ctor_name, - name_assignment); + if (params.size() <= 1) { + return new CtorPredicate(ctor_name.location, params, ctor_name, + name_assignment); + } else { + return new CtorPredicate( + ctor_name.location, + {new TuplePredicate(location, params, maybe())}, ctor_name, + name_assignment); + } } const Predicate *parse_tuple_predicate(ParseState &ps, @@ -2018,8 +2026,6 @@ const Expr *parse_lambda(ParseState &ps, } auto predicate = new TuplePredicate(ps.token.location, params, maybe()); - // TODO: convert the lambda params to a tuple if they are >=2 and - // destructure their names into the given ids. return new Lambda( param_id, param_type, return_type, new Match(new Var(param_id), diff --git a/src/patterns.cpp b/src/patterns.cpp index 868f662d..d63b8626 100644 --- a/src/patterns.cpp +++ b/src/patterns.cpp @@ -196,7 +196,8 @@ const Expr *Literal::translate( types::Ref type = get_tracked_type(tracked_types, this); auto Bool = type_id(make_iid(BOOL_TYPE)); Var *literal_cmp = new Var(make_iid("std.==")); - types::Ref cmp_type = type_arrow(type_tuple({type, type}), Bool); + types::Ref cmp_type = type_arrow(type_params({type_tuple({type, type})}), + Bool); typing[literal_cmp] = cmp_type; insert_needed_defn(needed_defns, types::DefnId{literal_cmp->id, cmp_type}, @@ -292,6 +293,9 @@ const Expr *translate_next(const types::DefnId &for_defn_id, 0 /*ignored in gen phase*/); typing[dim] = param_types[param_index]; + assert(params.size() > param_index); + assert(param_types.size() > param_index); + auto body = params[param_index]->translate( for_defn_id, param_id, param_types[param_index], do_checks, data_ctors_map, bound_vars, tracked_types, type_env, typing, needed_defns, @@ -331,10 +335,15 @@ const Expr *CtorPredicate::translate( static auto Bool = type_bool(INTERNAL_LOC()); types::Ref ctor_type = get_data_ctor_type(data_ctors_map, scrutinee_type, ctor_name); - types::Refs ctor_terms = unfold_arrows(ctor_type); + types::Refs ctor_terms = get_ctor_param_terms(unfold_arrows(ctor_type)); - assert(ctor_terms.size() >= 1); - ctor_terms = vec_slice(ctor_terms, 0, int(ctor_terms.size()) - 1); + debug_above(2, + log_location(get_location(), + "in ctor %s scrutinee type %s has terms %s", + ctor_name.str().c_str(), scrutinee_type->str().c_str(), + ::str(ctor_terms).c_str())); + // assert(ctor_terms.size() >= 1); + // ctor_terms = vec_slice(ctor_terms, 0, int(ctor_terms.size()) - 1); types::Ref resolved_scrutinee_type = scrutinee_type->eval(type_env, true /*shallow*/); @@ -422,8 +431,8 @@ const Expr *CtorPredicate::translate( typing[condition] = type_bool(INTERNAL_LOC()); } else { Var *cmp_ctor_id = new Var(make_iid("__builtin_cmp_ctor_id")); - typing[cmp_ctor_id] = type_arrow(type_tuple({scrutinee_type, Int}), - Bool); + typing[cmp_ctor_id] = type_arrow( + type_params({type_tuple({scrutinee_type, Int})}), Bool); condition = new Builtin(cmp_ctor_id, {scrutinee, ctor_id_literal}); typing[condition] = type_bool(INTERNAL_LOC()); diff --git a/src/translate.cpp b/src/translate.cpp index e524b9b7..678b844c 100644 --- a/src/translate.cpp +++ b/src/translate.cpp @@ -150,7 +150,7 @@ const Expr *texpr(const types::DefnId &for_defn_id, types::Refs terms = unfold_arrows(operator_type); assert(terms.size() > 1); - types::Ref resolution_type = type_arrow(operand_type, type); + types::Ref resolution_type = type_arrow(type_params({operand_type}), type); types::Unification unification = unify(operator_type, resolution_type); assert(unification.result); operator_type = operator_type->rebind(unification.bindings); @@ -278,6 +278,17 @@ const Expr *texpr(const types::DefnId &for_defn_id, assert(typing.count(expr)); return expr; } + } else if (auto ffi = dcast(expr)) { + std::vector exprs; + for (auto expr : ffi->exprs) { + exprs.push_back(texpr(for_defn_id, expr, data_ctors_map, bound_vars, + tracked_types, + get_tracked_type(tracked_types, expr), type_env, + typing, needed_defns, returns)); + } + auto new_ffi = new FFI(ffi->id, exprs); + typing[new_ffi] = type; + return new_ffi; } else if (auto builtin = dcast(expr)) { std::vector exprs; for (auto expr : builtin->exprs) { diff --git a/src/types.cpp b/src/types.cpp index 3e0b3b08..87120229 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -841,9 +841,9 @@ types::Ref type_arrows(types::Refs types) { auto return_type = types.back(); types.resize(types.size() - 1); if (types.size() == 1) { - return type_arrow(types[0], return_type); + return type_arrow(type_params({types[0]}), return_type); } else { - return type_arrow(type_tuple(types), return_type); + return type_arrow(type_params({type_tuple(types)}), return_type); } } @@ -853,14 +853,11 @@ types::Refs unfold_arrows(types::Ref type) { auto nested_op = dyncast(op->oper); if (nested_op != nullptr) { if (is_type_id(nested_op->oper, ARROW_TYPE_OPERATOR)) { - if (auto type_params = dyncast( - nested_op->operand)) { - types::Refs terms = type_params->dimensions; - terms.push_back(op->operand); - return terms; - } else { - return {nested_op->operand, op->operand}; - } + auto type_params = safe_dyncast( + nested_op->operand); + types::Refs terms = type_params->dimensions; + terms.push_back(op->operand); + return terms; } } } @@ -883,7 +880,7 @@ types::Ref type_tuple_accessor(int i, for (int j = 0; j < max; ++j) { dims.push_back(type_variable(make_iid(vars[j]))); } - return type_arrow(type_tuple(dims), + return type_arrow(type_params({type_tuple(dims)}), type_variable(make_iid(vars[i]))); } @@ -1002,39 +999,30 @@ types::Ref tuple_deref_type(Location location, return tuple->dimensions[index]; } -types::Refs get_ctor_terms(Location location, - std::string ctor_name, - const types::Refs &outer_ctor_terms, - int params_count) { +types::Refs get_ctor_param_terms(Location location, + std::string ctor_name, + const types::Refs &outer_ctor_terms, + int params_count) { assert(outer_ctor_terms.size() >= 1); - types::Refs ctor_terms = get_ctor_terms(outer_ctor_terms); + types::Refs ctor_param_terms = get_ctor_param_terms(outer_ctor_terms); - if (ctor_terms.size() != params_count) { - throw zion::user_error( - location, "incorrect number of sub-patterns given to %s (%d vs. %d)", - ctor_name.c_str(), ctor_terms.size(), params_count); + if (ctor_param_terms.size() != params_count) { + throw zion::user_error(location, + "incorrect number of sub-patterns given to %s (%d " + "vs. %d) (outer terms = %s)", + ctor_name.c_str(), ctor_param_terms.size(), + params_count, str(outer_ctor_terms).c_str()); } - return ctor_terms; + return ctor_param_terms; } -types::Refs get_ctor_terms(const types::Refs &outer_ctor_terms) { - assert(outer_ctor_terms.size() >= 1); - - types::Refs ctor_terms; - +types::Refs get_ctor_param_terms(const types::Refs &outer_ctor_terms) { if (outer_ctor_terms.size() > 1) { assert(outer_ctor_terms.size() == 2); - if (auto ctor_tuple = dyncast( - outer_ctor_terms[0])) { - ctor_terms = ctor_tuple->dimensions; - } else { - // Handle single parameter data ctors. This only works because we have - // banished unary tuples. - ctor_terms.push_back(outer_ctor_terms[0]); - } + return {outer_ctor_terms.front()}; + } else { + return {}; } - - return ctor_terms; } diff --git a/src/types.h b/src/types.h index 0ee9d1ea..87e7c2e8 100644 --- a/src/types.h +++ b/src/types.h @@ -280,8 +280,8 @@ std::ostream &join_dimensions(std::ostream &os, const types::Map &bindings); std::string get_name_from_index(const types::NameIndex &name_index, int i); bool is_valid_udt_initial_char(int ch); -types::Refs get_ctor_terms(const types::Refs &outer_ctor_terms); -types::Refs get_ctor_terms(Location location, - std::string ctor_name, - const types::Refs &outer_ctor_terms, - int params_count); +types::Refs get_ctor_param_terms(const types::Refs &outer_ctor_terms); +types::Refs get_ctor_param_terms(Location location, + std::string ctor_name, + const types::Refs &outer_ctor_terms, + int params_count); diff --git a/src/utils.h b/src/utils.h index 832b7b1c..cc9cfffa 100644 --- a/src/utils.h +++ b/src/utils.h @@ -348,6 +348,10 @@ bool any_in(const C1 &needles, const C2 &haystack) { template std::vector vec_slice(const std::vector &orig, int start, int lim) { + assert(start <= orig.size()); + assert(lim <= orig.size()); + assert(lim >= start); + std::vector output; output.reserve(lim - start); for (int i = start; i < lim; ++i) {