Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Will Bradley committed Jul 14, 2020
1 parent 96ff650 commit 79544cc
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 75 deletions.
8 changes: 2 additions & 6 deletions src/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ Location Tuple::get_location() const {
std::ostream &Tuple::render(std::ostream &os, int parent_precedence) const {
os << "(";
for (auto dim : dims) {
dim->render(os, 0);
dim->render(os, 10);
os << ",";
}
return os << ")";
Expand Down Expand Up @@ -288,11 +288,7 @@ Location FFI::get_location() const {
}

std::ostream &FFI::render(std::ostream &os, int parent_precedence) const {
os << C_WARN << "ffi(" << escape_json_quotes(id.name);
for (auto expr : exprs) {
os << ", " << expr->str();
}
return os << ")";
return os << C_WARN << "ffi " << id.str() << "(" << join_str(exprs) << ")";
}

Location Conditional::get_location() const {
Expand Down
1 change: 0 additions & 1 deletion src/data_ctors_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ types::Ref get_data_ctor_type(const DataCtorsMap &data_ctors_map,
std::map<std::string, types::Ref> get_data_ctors_types(
const DataCtorsMap &data_ctors_map,
types::Ref type) {
std::cerr << "unfolding " << type->str() << std::endl;
types::Refs type_terms;
unfold_ops_lassoc(type, type_terms);
assert(type_terms.size() != 0);
Expand Down
4 changes: 4 additions & 0 deletions src/gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ void get_free_vars(const ast::Expr *expr,
} else if (auto as = dcast<const ast::As *>(expr)) {
get_free_vars(as->expr, typing, globals, locals, free_vars);
} else if (dcast<const ast::Sizeof *>(expr)) {
} else if (auto ffi = dcast<const ast::FFI *>(expr)) {
for (auto expr : ffi->exprs) {
get_free_vars(expr, typing, globals, locals, free_vars);
}
} else if (auto builtin = dcast<const ast::Builtin *>(expr)) {
for (auto expr : builtin->exprs) {
get_free_vars(expr, typing, globals, locals, free_vars);
Expand Down
18 changes: 9 additions & 9 deletions src/infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ types::Ref infer_core(const Expr *expr,
"return type does not match type annotation :: %s",
lambda->return_type->str().c_str()));
}
return type_arrow(tv, local_return_type);
return type_arrow(type_params({tv}), local_return_type);
} else if (auto application = dcast<const Application *>(expr)) {
auto t1 = infer(application->a, data_ctors_map, return_type,
scheme_resolver, tracked_types, constraints,
instance_requirements);
auto t2 = infer(application->b, data_ctors_map, return_type,
scheme_resolver, tracked_types, constraints,
instance_requirements);
auto t2 = type_params(
{infer(application->b, data_ctors_map, return_type, scheme_resolver,
tracked_types, constraints, instance_requirements)});
auto tv = type_variable(expr->get_location());
append_to_constraints(
constraints, t1, type_arrow(application->get_location(), t2, tv),
Expand Down Expand Up @@ -131,9 +131,9 @@ types::Ref infer_core(const Expr *expr,
make_context(defer->get_location(),
"defer must call nullary function"));

auto t2 = infer(defer->application->b, data_ctors_map, return_type,
scheme_resolver, tracked_types, constraints,
instance_requirements);
auto t2 = type_params({infer(defer->application->b, data_ctors_map,
return_type, scheme_resolver, tracked_types,
constraints, instance_requirements)});
append_to_constraints(
constraints, t1,
type_arrow(defer->application->get_location(), t2,
Expand Down Expand Up @@ -374,8 +374,8 @@ types::Ref CtorPredicate::tracking_infer(
ctor_type->str().c_str()));

types::Refs outer_ctor_terms = unfold_arrows(ctor_type);
types::Refs ctor_terms = get_ctor_terms(get_location(), ctor_name.str(),
outer_ctor_terms, params.size());
types::Refs ctor_terms = get_ctor_param_terms(
get_location(), ctor_name.str(), outer_ctor_terms, params.size());

types::Ref result_type;
for (size_t i = 0; i < params.size(); ++i) {
Expand Down
2 changes: 1 addition & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ CheckedDefinitionsByName check_decls(std::string entry_point_name,
append_to_constraints(
constraints, ty,
type_arrow(INTERNAL_LOC(),
type_unit(INTERNAL_LOC()),
type_params({type_unit(INTERNAL_LOC())}),
type_unit(INTERNAL_LOC())),
make_context(INTERNAL_LOC(),
"main function must have signature fn () ()"));
Expand Down
13 changes: 7 additions & 6 deletions src/match.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,8 @@ Pattern::ref from_type(Location location,
return allFloats;
} else if (unify(type, type_ptr(type_variable(location))).result) {
return all_of(location, {}, data_ctors_map, type);
} else if (unify(type,
type_arrow(type_variable(location), type_variable(location)))
} else if (unify(type, type_arrow(type_params({type_variable(location)}),
type_variable(location)))
.result) {
return all_of(location, {}, data_ctors_map, type);
} else {
Expand All @@ -410,7 +410,7 @@ Pattern::ref from_type(Location location,

for (auto pair : ctors_types) {
auto &ctor_name = pair.first;
auto ctor_terms = get_ctor_terms(unfold_arrows(pair.second));
auto ctor_terms = get_ctor_param_terms(unfold_arrows(pair.second));

std::vector<Pattern::ref> args;
if (ctor_terms.size() != 0) {
Expand Down Expand Up @@ -732,16 +732,17 @@ Pattern::ref CtorPredicate::get_pattern(
auto outer_ctor_terms = unfold_arrows(
get_data_ctor_type(data_ctors_map, type, ctor_name));

types::Refs ctor_terms = get_ctor_terms(get_location(), ctor_name.str(),
types::Refs ctor_terms = get_ctor_param_terms(get_location(), ctor_name.str(),
outer_ctor_terms, params.size());

std::vector<Pattern::ref> args;
for (size_t i = 0; i < params.size(); ++i) {
args.push_back(params[i]->get_pattern(ctor_terms[i], data_ctors_map));
}

std::cerr << "Pattern for CtorPredicate for " << type->str() << " has args "
<< join_str(args) << std::endl;
debug_above(4, log("pattern for CtorPredicate for %s has args %s",
type->str().c_str(), join_str(args).c_str()));

/* found the ctor we're matching on */
return std::make_shared<CtorPattern>(
location, CtorPatternValue{type->repr(), ctor_name.name, args});
Expand Down
14 changes: 10 additions & 4 deletions src/parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,7 @@ const Predicate *parse_ctor_predicate(ParseState &ps,
Identifier ctor_name = ps.identifier_and_advance();

std::vector<const Predicate *> params;
auto location = ps.token.location;
if (ps.token.tk == tk_lparen) {
ps.advance();
while (ps.token.tk != tk_rparen) {
Expand All @@ -1751,8 +1752,15 @@ const Predicate *parse_ctor_predicate(ParseState &ps,
}
chomp_token(tk_rparen);
}
return new CtorPredicate(ctor_name.location, params, ctor_name,
name_assignment);
if (params.size() <= 1) {
return new CtorPredicate(ctor_name.location, params, ctor_name,
name_assignment);
} else {
return new CtorPredicate(
ctor_name.location,
{new TuplePredicate(location, params, maybe<Identifier>())}, ctor_name,
name_assignment);
}
}

const Predicate *parse_tuple_predicate(ParseState &ps,
Expand Down Expand Up @@ -2018,8 +2026,6 @@ const Expr *parse_lambda(ParseState &ps,
}
auto predicate = new TuplePredicate(ps.token.location, params,
maybe<Identifier>());
// TODO: convert the lambda params to a tuple if they are >=2 and
// destructure their names into the given ids.
return new Lambda(
param_id, param_type, return_type,
new Match(new Var(param_id),
Expand Down
21 changes: 15 additions & 6 deletions src/patterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ const Expr *Literal::translate(
types::Ref type = get_tracked_type(tracked_types, this);
auto Bool = type_id(make_iid(BOOL_TYPE));
Var *literal_cmp = new Var(make_iid("std.=="));
types::Ref cmp_type = type_arrow(type_tuple({type, type}), Bool);
types::Ref cmp_type = type_arrow(type_params({type_tuple({type, type})}),
Bool);

typing[literal_cmp] = cmp_type;
insert_needed_defn(needed_defns, types::DefnId{literal_cmp->id, cmp_type},
Expand Down Expand Up @@ -292,6 +293,9 @@ const Expr *translate_next(const types::DefnId &for_defn_id,
0 /*ignored in gen phase*/);
typing[dim] = param_types[param_index];

assert(params.size() > param_index);
assert(param_types.size() > param_index);

auto body = params[param_index]->translate(
for_defn_id, param_id, param_types[param_index], do_checks,
data_ctors_map, bound_vars, tracked_types, type_env, typing, needed_defns,
Expand Down Expand Up @@ -331,10 +335,15 @@ const Expr *CtorPredicate::translate(
static auto Bool = type_bool(INTERNAL_LOC());
types::Ref ctor_type = get_data_ctor_type(data_ctors_map, scrutinee_type,
ctor_name);
types::Refs ctor_terms = unfold_arrows(ctor_type);
types::Refs ctor_terms = get_ctor_param_terms(unfold_arrows(ctor_type));

assert(ctor_terms.size() >= 1);
ctor_terms = vec_slice(ctor_terms, 0, int(ctor_terms.size()) - 1);
debug_above(2,
log_location(get_location(),
"in ctor %s scrutinee type %s has terms %s",
ctor_name.str().c_str(), scrutinee_type->str().c_str(),
::str(ctor_terms).c_str()));
// assert(ctor_terms.size() >= 1);
// ctor_terms = vec_slice(ctor_terms, 0, int(ctor_terms.size()) - 1);

types::Ref resolved_scrutinee_type = scrutinee_type->eval(type_env,
true /*shallow*/);
Expand Down Expand Up @@ -422,8 +431,8 @@ const Expr *CtorPredicate::translate(
typing[condition] = type_bool(INTERNAL_LOC());
} else {
Var *cmp_ctor_id = new Var(make_iid("__builtin_cmp_ctor_id"));
typing[cmp_ctor_id] = type_arrow(type_tuple({scrutinee_type, Int}),
Bool);
typing[cmp_ctor_id] = type_arrow(
type_params({type_tuple({scrutinee_type, Int})}), Bool);

condition = new Builtin(cmp_ctor_id, {scrutinee, ctor_id_literal});
typing[condition] = type_bool(INTERNAL_LOC());
Expand Down
13 changes: 12 additions & 1 deletion src/translate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ const Expr *texpr(const types::DefnId &for_defn_id,
types::Refs terms = unfold_arrows(operator_type);
assert(terms.size() > 1);

types::Ref resolution_type = type_arrow(operand_type, type);
types::Ref resolution_type = type_arrow(type_params({operand_type}), type);
types::Unification unification = unify(operator_type, resolution_type);
assert(unification.result);
operator_type = operator_type->rebind(unification.bindings);
Expand Down Expand Up @@ -278,6 +278,17 @@ const Expr *texpr(const types::DefnId &for_defn_id,
assert(typing.count(expr));
return expr;
}
} else if (auto ffi = dcast<const FFI *>(expr)) {
std::vector<const Expr *> exprs;
for (auto expr : ffi->exprs) {
exprs.push_back(texpr(for_defn_id, expr, data_ctors_map, bound_vars,
tracked_types,
get_tracked_type(tracked_types, expr), type_env,
typing, needed_defns, returns));
}
auto new_ffi = new FFI(ffi->id, exprs);
typing[new_ffi] = type;
return new_ffi;
} else if (auto builtin = dcast<const Builtin *>(expr)) {
std::vector<const Expr *> exprs;
for (auto expr : builtin->exprs) {
Expand Down
60 changes: 24 additions & 36 deletions src/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,9 @@ types::Ref type_arrows(types::Refs types) {
auto return_type = types.back();
types.resize(types.size() - 1);
if (types.size() == 1) {
return type_arrow(types[0], return_type);
return type_arrow(type_params({types[0]}), return_type);
} else {
return type_arrow(type_tuple(types), return_type);
return type_arrow(type_params({type_tuple(types)}), return_type);
}
}

Expand All @@ -853,14 +853,11 @@ types::Refs unfold_arrows(types::Ref type) {
auto nested_op = dyncast<const types::TypeOperator>(op->oper);
if (nested_op != nullptr) {
if (is_type_id(nested_op->oper, ARROW_TYPE_OPERATOR)) {
if (auto type_params = dyncast<const types::TypeParams>(
nested_op->operand)) {
types::Refs terms = type_params->dimensions;
terms.push_back(op->operand);
return terms;
} else {
return {nested_op->operand, op->operand};
}
auto type_params = safe_dyncast<const types::TypeParams>(
nested_op->operand);
types::Refs terms = type_params->dimensions;
terms.push_back(op->operand);
return terms;
}
}
}
Expand All @@ -883,7 +880,7 @@ types::Ref type_tuple_accessor(int i,
for (int j = 0; j < max; ++j) {
dims.push_back(type_variable(make_iid(vars[j])));
}
return type_arrow(type_tuple(dims),
return type_arrow(type_params({type_tuple(dims)}),
type_variable(make_iid(vars[i])));
}

Expand Down Expand Up @@ -1002,39 +999,30 @@ types::Ref tuple_deref_type(Location location,
return tuple->dimensions[index];
}

types::Refs get_ctor_terms(Location location,
std::string ctor_name,
const types::Refs &outer_ctor_terms,
int params_count) {
types::Refs get_ctor_param_terms(Location location,
std::string ctor_name,
const types::Refs &outer_ctor_terms,
int params_count) {
assert(outer_ctor_terms.size() >= 1);

types::Refs ctor_terms = get_ctor_terms(outer_ctor_terms);
types::Refs ctor_param_terms = get_ctor_param_terms(outer_ctor_terms);

if (ctor_terms.size() != params_count) {
throw zion::user_error(
location, "incorrect number of sub-patterns given to %s (%d vs. %d)",
ctor_name.c_str(), ctor_terms.size(), params_count);
if (ctor_param_terms.size() != params_count) {
throw zion::user_error(location,
"incorrect number of sub-patterns given to %s (%d "
"vs. %d) (outer terms = %s)",
ctor_name.c_str(), ctor_param_terms.size(),
params_count, str(outer_ctor_terms).c_str());
}

return ctor_terms;
return ctor_param_terms;
}

types::Refs get_ctor_terms(const types::Refs &outer_ctor_terms) {
assert(outer_ctor_terms.size() >= 1);

types::Refs ctor_terms;

types::Refs get_ctor_param_terms(const types::Refs &outer_ctor_terms) {
if (outer_ctor_terms.size() > 1) {
assert(outer_ctor_terms.size() == 2);
if (auto ctor_tuple = dyncast<const types::TypeTuple>(
outer_ctor_terms[0])) {
ctor_terms = ctor_tuple->dimensions;
} else {
// Handle single parameter data ctors. This only works because we have
// banished unary tuples.
ctor_terms.push_back(outer_ctor_terms[0]);
}
return {outer_ctor_terms.front()};
} else {
return {};
}

return ctor_terms;
}
10 changes: 5 additions & 5 deletions src/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,8 @@ std::ostream &join_dimensions(std::ostream &os,
const types::Map &bindings);
std::string get_name_from_index(const types::NameIndex &name_index, int i);
bool is_valid_udt_initial_char(int ch);
types::Refs get_ctor_terms(const types::Refs &outer_ctor_terms);
types::Refs get_ctor_terms(Location location,
std::string ctor_name,
const types::Refs &outer_ctor_terms,
int params_count);
types::Refs get_ctor_param_terms(const types::Refs &outer_ctor_terms);
types::Refs get_ctor_param_terms(Location location,
std::string ctor_name,
const types::Refs &outer_ctor_terms,
int params_count);
4 changes: 4 additions & 0 deletions src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,10 @@ bool any_in(const C1 &needles, const C2 &haystack) {

template <typename V>
std::vector<V> vec_slice(const std::vector<V> &orig, int start, int lim) {
assert(start <= orig.size());
assert(lim <= orig.size());
assert(lim >= start);

std::vector<V> output;
output.reserve(lim - start);
for (int i = start; i < lim; ++i) {
Expand Down

0 comments on commit 79544cc

Please sign in to comment.