aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/check_stmt.cpp8
-rw-r--r--src/llvm_backend_stmt.cpp19
-rw-r--r--src/parser.cpp29
-rw-r--r--src/parser.hpp2
4 files changed, 51 insertions, 7 deletions
diff --git a/src/check_stmt.cpp b/src/check_stmt.cpp
index bae95b7c7..0701ff9d8 100644
--- a/src/check_stmt.cpp
+++ b/src/check_stmt.cpp
@@ -1705,11 +1705,17 @@ gb_internal void check_range_stmt(CheckerContext *ctx, Ast *node, u32 mod_flags)
TEMPORARY_ALLOCATOR_GUARD();
- u32 new_flags = mod_flags | Stmt_BreakAllowed | Stmt_ContinueAllowed;
check_open_scope(ctx, node);
check_label(ctx, rs->label, node);
+ Operand init = {};
+ if (rs->init != nullptr) {
+ check_stmt(ctx, rs->init, mod_flags);
+ }
+
+ u32 new_flags = mod_flags | Stmt_BreakAllowed | Stmt_ContinueAllowed;
+
auto vals = array_make<Type *>(temporary_allocator(), 0, 2);
auto entities = array_make<Entity *>(temporary_allocator(), 0, 2);
bool is_map = false;
diff --git a/src/llvm_backend_stmt.cpp b/src/llvm_backend_stmt.cpp
index 05ec10cda..98b45f646 100644
--- a/src/llvm_backend_stmt.cpp
+++ b/src/llvm_backend_stmt.cpp
@@ -756,6 +756,10 @@ gb_internal void lb_build_range_interval(lbProcedure *p, AstBinaryExpr *node,
lb_open_scope(p, scope);
+ if (rs->init != nullptr) {
+ lb_build_stmt(p, rs->init);
+ }
+
Ast *val0 = rs->vals.count > 0 ? lb_strip_and_prefix(rs->vals[0]) : nullptr;
Ast *val1 = rs->vals.count > 1 ? lb_strip_and_prefix(rs->vals[1]) : nullptr;
Type *val0_type = nullptr;
@@ -948,6 +952,10 @@ gb_internal void lb_build_range_tuple(lbProcedure *p, AstRangeStmt *rs, Scope *s
lb_open_scope(p, scope);
+ if (rs->init != nullptr) {
+ lb_build_stmt(p, rs->init);
+ }
+
lbBlock *loop = lb_create_block(p, "for.tuple.loop");
lb_emit_jump(p, loop);
lb_start_block(p, loop);
@@ -1002,6 +1010,9 @@ gb_internal void lb_build_range_stmt_struct_soa(lbProcedure *p, AstRangeStmt *rs
lb_open_scope(p, scope);
+ if (rs->init != nullptr) {
+ lb_build_stmt(p, rs->init);
+ }
Ast *val0 = rs->vals.count > 0 ? lb_strip_and_prefix(rs->vals[0]) : nullptr;
Ast *val1 = rs->vals.count > 1 ? lb_strip_and_prefix(rs->vals[1]) : nullptr;
@@ -1153,6 +1164,10 @@ gb_internal void lb_build_range_stmt(lbProcedure *p, AstRangeStmt *rs, Scope *sc
lb_open_scope(p, scope);
+ if (rs->init != nullptr) {
+ lb_build_stmt(p, rs->init);
+ }
+
Ast *val0 = rs->vals.count > 0 ? lb_strip_and_prefix(rs->vals[0]) : nullptr;
Ast *val1 = rs->vals.count > 1 ? lb_strip_and_prefix(rs->vals[1]) : nullptr;
Type *val0_type = nullptr;
@@ -1352,6 +1367,10 @@ gb_internal void lb_build_unroll_range_stmt(lbProcedure *p, AstUnrollRangeStmt *
lb_open_scope(p, scope); // Open scope here
+ if (rs->init != nullptr) {
+ lb_build_stmt(p, rs->init);
+ }
+
Ast *val0 = lb_strip_and_prefix(rs->val0);
Ast *val1 = lb_strip_and_prefix(rs->val1);
Type *val0_type = nullptr;
diff --git a/src/parser.cpp b/src/parser.cpp
index 159eb65f8..42c69cd4c 100644
--- a/src/parser.cpp
+++ b/src/parser.cpp
@@ -353,12 +353,14 @@ gb_internal Ast *clone_ast(Ast *node, AstFile *f) {
break;
case Ast_RangeStmt:
n->RangeStmt.label = clone_ast(n->RangeStmt.label, f);
+ n->RangeStmt.init = clone_ast(n->RangeStmt.init, f);
n->RangeStmt.vals = clone_ast_array(n->RangeStmt.vals, f);
n->RangeStmt.expr = clone_ast(n->RangeStmt.expr, f);
n->RangeStmt.body = clone_ast(n->RangeStmt.body, f);
break;
case Ast_UnrollRangeStmt:
n->UnrollRangeStmt.args = clone_ast_array(n->UnrollRangeStmt.args, f);
+ n->UnrollRangeStmt.init = clone_ast(n->UnrollRangeStmt.init, f);
n->UnrollRangeStmt.val0 = clone_ast(n->UnrollRangeStmt.val0, f);
n->UnrollRangeStmt.val1 = clone_ast(n->UnrollRangeStmt.val1, f);
n->UnrollRangeStmt.expr = clone_ast(n->UnrollRangeStmt.expr, f);
@@ -1055,9 +1057,10 @@ gb_internal Ast *ast_for_stmt(AstFile *f, Token token, Ast *init, Ast *cond, Ast
return result;
}
-gb_internal Ast *ast_range_stmt(AstFile *f, Token token, Slice<Ast *> vals, Token in_token, Ast *expr, Ast *body) {
+gb_internal Ast *ast_range_stmt(AstFile *f, Token token, Ast *init, Slice<Ast *> vals, Token in_token, Ast *expr, Ast *body) {
Ast *result = alloc_ast_node(f, Ast_RangeStmt);
result->RangeStmt.token = token;
+ result->RangeStmt.init = init;
result->RangeStmt.vals = vals;
result->RangeStmt.in_token = in_token;
result->RangeStmt.expr = expr;
@@ -1065,9 +1068,10 @@ gb_internal Ast *ast_range_stmt(AstFile *f, Token token, Slice<Ast *> vals, Toke
return result;
}
-gb_internal Ast *ast_unroll_range_stmt(AstFile *f, Token unroll_token, Slice<Ast *> args, Token for_token, Ast *val0, Ast *val1, Token in_token, Ast *expr, Ast *body) {
+gb_internal Ast *ast_unroll_range_stmt(AstFile *f, Token unroll_token, Ast *init, Slice<Ast *> args, Token for_token, Ast *val0, Ast *val1, Token in_token, Ast *expr, Ast *body) {
Ast *result = alloc_ast_node(f, Ast_UnrollRangeStmt);
result->UnrollRangeStmt.unroll_token = unroll_token;
+ result->UnrollRangeStmt.init = init;
result->UnrollRangeStmt.args = args;
result->UnrollRangeStmt.for_token = for_token;
result->UnrollRangeStmt.val0 = val0;
@@ -4883,7 +4887,7 @@ gb_internal Ast *parse_for_stmt(AstFile *f) {
body = parse_block_stmt(f, false);
}
- return ast_range_stmt(f, token, {}, in_token, rhs, body);
+ return ast_range_stmt(f, token, init, {}, in_token, rhs, body);
}
if (f->curr_token.kind != Token_Semicolon) {
@@ -4898,9 +4902,20 @@ gb_internal Ast *parse_for_stmt(AstFile *f) {
cond = nullptr;
if (f->curr_token.kind == Token_OpenBrace || f->curr_token.kind == Token_do) {
- syntax_error(f->curr_token, "Expected ';', followed by a condition expression and post statement, got %.*s", LIT(token_strings[f->curr_token.kind]));
+ syntax_error(f->curr_token, "Expected ';', followed by a condition expression and post statement, or 'x in y' style loop, got %.*s", LIT(token_strings[f->curr_token.kind]));
} else {
if (f->curr_token.kind != Token_Semicolon) {
+ if (f->curr_token.kind == Token_Ident) {
+ // for init; x in y { }
+ Token next_token = peek_token(f);
+ if (next_token.kind == Token_in || next_token.kind == Token_Comma) {
+ cond = parse_simple_stmt(f, StmtAllowFlag_In);
+ GB_ASSERT(cond->kind == Ast_AssignStmt && cond->AssignStmt.op.kind == Token_in);
+ is_range = true;
+ goto range_skip;
+ }
+ }
+
cond = parse_simple_stmt(f, StmtAllowFlag_None);
}
@@ -4918,6 +4933,7 @@ gb_internal Ast *parse_for_stmt(AstFile *f) {
}
}
+range_skip:;
if (allow_token(f, Token_do)) {
body = parse_do_body(f, token, "the for statement");
@@ -4933,7 +4949,7 @@ gb_internal Ast *parse_for_stmt(AstFile *f) {
if (cond->AssignStmt.rhs.count > 0) {
rhs = cond->AssignStmt.rhs[0];
}
- return ast_range_stmt(f, token, vals, in_token, rhs, body);
+ return ast_range_stmt(f, token, init, vals, in_token, rhs, body);
}
cond = convert_stmt_to_expr(f, cond, str_lit("boolean expression"));
@@ -5267,6 +5283,7 @@ gb_internal Ast *parse_unrolled_for_loop(AstFile *f, Token unroll_token) {
}
Token for_token = expect_token(f, Token_for);
+ Ast *init = nullptr;
Ast *val0 = nullptr;
Ast *val1 = nullptr;
Token in_token = {};
@@ -5309,7 +5326,7 @@ gb_internal Ast *parse_unrolled_for_loop(AstFile *f, Token unroll_token) {
if (bad_stmt) {
return ast_bad_stmt(f, unroll_token, f->curr_token);
}
- return ast_unroll_range_stmt(f, unroll_token, slice_from_array(args), for_token, val0, val1, in_token, expr, body);
+ return ast_unroll_range_stmt(f, unroll_token, init, slice_from_array(args), for_token, val0, val1, in_token, expr, body);
}
gb_internal Ast *parse_stmt(AstFile *f) {
diff --git a/src/parser.hpp b/src/parser.hpp
index 1026433d0..d3527285d 100644
--- a/src/parser.hpp
+++ b/src/parser.hpp
@@ -587,6 +587,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \
Scope *scope; \
Token token; \
Ast *label; \
+ Ast *init; \
Slice<Ast *> vals; \
Token in_token; \
Ast *expr; \
@@ -596,6 +597,7 @@ AST_KIND(_ComplexStmtBegin, "", bool) \
AST_KIND(UnrollRangeStmt, "#unroll range statement", struct { \
Scope *scope; \
Token unroll_token; \
+ Ast *init; \
Slice<Ast *> args; \
Token for_token; \
Ast *val0; \