From 4c655865e5d9af83a98c137609b01972f4e51beb Mon Sep 17 00:00:00 2001 From: gingerBill Date: Mon, 18 Oct 2021 16:52:19 +0100 Subject: Begin work on matrix type --- src/checker.cpp | 2 ++ 1 file changed, 2 insertions(+) (limited to 'src/checker.cpp') diff --git a/src/checker.cpp b/src/checker.cpp index d3c0080de..8711fdc0c 100644 --- a/src/checker.cpp +++ b/src/checker.cpp @@ -2458,6 +2458,7 @@ void init_core_type_info(Checker *c) { t_type_info_simd_vector = find_core_type(c, str_lit("Type_Info_Simd_Vector")); t_type_info_relative_pointer = find_core_type(c, str_lit("Type_Info_Relative_Pointer")); t_type_info_relative_slice = find_core_type(c, str_lit("Type_Info_Relative_Slice")); + t_type_info_matrix = find_core_type(c, str_lit("Type_Info_Matrix")); t_type_info_named_ptr = alloc_type_pointer(t_type_info_named); t_type_info_integer_ptr = alloc_type_pointer(t_type_info_integer); @@ -2485,6 +2486,7 @@ void init_core_type_info(Checker *c) { t_type_info_simd_vector_ptr = alloc_type_pointer(t_type_info_simd_vector); t_type_info_relative_pointer_ptr = alloc_type_pointer(t_type_info_relative_pointer); t_type_info_relative_slice_ptr = alloc_type_pointer(t_type_info_relative_slice); + t_type_info_matrix_ptr = alloc_type_pointer(t_type_info_matrix); } void init_mem_allocator(Checker *c) { -- cgit v1.2.3 From ba331024af2f5074125442e91dda6c8e63324c8f Mon Sep 17 00:00:00 2001 From: gingerBill Date: Mon, 18 Oct 2021 18:16:52 +0100 Subject: Very basic matrix support in backend --- core/fmt/fmt.odin | 35 +++++++++- src/check_expr.cpp | 153 ++++++++++++++++++++++++++++++++++++++++++- src/checker.cpp | 8 +++ src/llvm_backend.hpp | 4 ++ src/llvm_backend_const.cpp | 28 ++++++++ src/llvm_backend_expr.cpp | 78 ++++++++++++++++++++++ src/llvm_backend_utility.cpp | 35 ++++++++++ src/types.cpp | 31 +++++++-- 8 files changed, 364 insertions(+), 8 deletions(-) (limited to 'src/checker.cpp') diff --git a/core/fmt/fmt.odin b/core/fmt/fmt.odin index cee00da23..804a29cab 100644 --- a/core/fmt/fmt.odin +++ b/core/fmt/fmt.odin @@ -1954,7 +1954,40 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { } case runtime.Type_Info_Matrix: - io.write_string(fi.writer, "[]") + reflect.write_type(fi.writer, type_info_of(v.id)) + io.write_byte(fi.writer, '{') + defer io.write_byte(fi.writer, '}') + + fi.indent += 1; defer fi.indent -= 1 + + if fi.hash { + io.write_byte(fi.writer, '\n') + // TODO(bill): Should this render it like in written form? e.g. tranposed + for col in 0.. 0 { io.write_string(fi.writer, ", ") } + + offset := row*info.elem_size + col*info.stride + + data := uintptr(v.data) + uintptr(offset) + fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) + } + io.write_string(fi.writer, ";\n") + } + } else { + for col in 0.. 0 { io.write_string(fi.writer, "; ") } + for row in 0.. 0 { io.write_string(fi.writer, ", ") } + + offset := row*info.elem_size + col*info.stride + + data := uintptr(v.data) + uintptr(offset) + fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) + } + } + } } } diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 85f2eeb23..9c12802d7 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -1400,8 +1400,9 @@ bool check_unary_op(CheckerContext *c, Operand *o, Token op) { } bool check_binary_op(CheckerContext *c, Operand *o, Token op) { + Type *main_type = o->type; // TODO(bill): Handle errors correctly - Type *type = base_type(core_array_type(o->type)); + Type *type = base_type(core_array_type(main_type)); Type *ct = core_type(type); switch (op.kind) { case Token_Sub: @@ -1414,10 +1415,15 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { } break; - case Token_Mul: case Token_Quo: - case Token_MulEq: case Token_QuoEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } + /*fallthrough*/ + case Token_Mul: + case Token_MulEq: case Token_AddEq: if (is_type_bit_set(type)) { return true; @@ -1458,6 +1464,10 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { case Token_ModMod: case Token_ModEq: case Token_ModModEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } if (!is_type_integer(type)) { error(op, "Operator '%.*s' is only allowed with integers", LIT(op.string)); return false; @@ -2671,6 +2681,114 @@ bool can_use_other_type_as_type_hint(bool use_lhs_as_type_hint, Type *other_type } +void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand *y, Type *type_hint, bool use_lhs_as_type_hint) { + if (!check_binary_op(c, x, op)) { + x->mode = Addressing_Invalid; + return; + } + + if (is_type_matrix(x->type)) { + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); + GB_ASSERT(xt->kind == Type_Matrix); + if (op.kind == Token_Mul) { + if (yt->kind == Type_Matrix) { + if (!are_types_identical(xt->Matrix.elem, yt->Matrix.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Matrix.row_count) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count); + goto matrix_success; + } else if (yt->kind == Type_Array) { + if (!are_types_identical(xt->Matrix.elem, yt->Array.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Array.count) { + goto matrix_error; + } + + // Treat arrays as column vectors + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1); + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } else { + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); + GB_ASSERT(is_type_matrix(yt)); + GB_ASSERT(!is_type_matrix(xt)); + + if (op.kind == Token_Mul) { + // NOTE(bill): no need to handle the matrix case here since it should be handled above + if (xt->kind == Type_Array) { + if (!are_types_identical(yt->Matrix.elem, xt->Array.elem)) { + goto matrix_error; + } + + if (xt->Array.count != yt->Matrix.row_count) { + goto matrix_error; + } + + // Treat arrays as row vectors + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, 1, xt->Matrix.column_count); + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } + +matrix_success: + if (type_hint != nullptr) { + Type *th = base_type(type_hint); + if (are_types_identical(th, x->type)) { + x->type = type_hint; + } else if (x->type->kind == Type_Matrix && th->kind == Type_Array) { + Type *xt = x->type; + if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) { + // ignore + } else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) { + x->type = type_hint; + } else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) { + x->type = type_hint; + } + } + } + return; + + +matrix_error: + gbString xt = type_to_string(x->type); + gbString yt = type_to_string(y->type); + gbString expr_str = expr_to_string(x->expr); + error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xt, yt); + gb_string_free(expr_str); + gb_string_free(yt); + gb_string_free(xt); + x->type = t_invalid; + x->mode = Addressing_Invalid; + return; + +} + + void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint, bool use_lhs_as_type_hint=false) { GB_ASSERT(node->kind == Ast_BinaryExpr); Operand y_ = {}, *y = &y_; @@ -2874,6 +2992,12 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint x->type = y->type; return; } + if (is_type_matrix(x->type) || is_type_matrix(y->type)) { + check_binary_matrix(c, op, x, y, type_hint, use_lhs_as_type_hint); + return; + } + + if (!are_types_identical(x->type, y->type)) { if (x->type != t_invalid && y->type != t_invalid) { @@ -3258,6 +3382,29 @@ void convert_to_typed(CheckerContext *c, Operand *operand, Type *target_type) { break; } + + case Type_Matrix: { + Type *elem = base_array_type(t); + if (check_is_assignable_to(c, operand, elem)) { + if (t->Matrix.row_count != t->Matrix.column_count) { + operand->mode = Addressing_Invalid; + begin_error_block(); + defer (end_error_block()); + + convert_untyped_error(c, operand, target_type); + error_line("\tNote: Only a square matrix types can be initialized with a scalar value\n"); + return; + } else { + operand->mode = Addressing_Value; + } + } else { + operand->mode = Addressing_Invalid; + convert_untyped_error(c, operand, target_type); + return; + } + break; + } + case Type_Union: if (!is_operand_nil(*operand) && !is_operand_undef(*operand)) { diff --git a/src/checker.cpp b/src/checker.cpp index 8711fdc0c..c0e6d47c0 100644 --- a/src/checker.cpp +++ b/src/checker.cpp @@ -1659,6 +1659,10 @@ void add_type_info_type_internal(CheckerContext *c, Type *t) { add_type_info_type_internal(c, bt->RelativeSlice.slice_type); add_type_info_type_internal(c, bt->RelativeSlice.base_integer); break; + + case Type_Matrix: + add_type_info_type_internal(c, bt->Matrix.elem); + break; default: GB_PANIC("Unhandled type: %*.s %d", LIT(type_strings[bt->kind]), bt->kind); @@ -1870,6 +1874,10 @@ void add_min_dep_type_info(Checker *c, Type *t) { add_min_dep_type_info(c, bt->RelativeSlice.slice_type); add_min_dep_type_info(c, bt->RelativeSlice.base_integer); break; + + case Type_Matrix: + add_min_dep_type_info(c, bt->Matrix.elem); + break; default: GB_PANIC("Unhandled type: %*.s", LIT(type_strings[bt->kind])); diff --git a/src/llvm_backend.hpp b/src/llvm_backend.hpp index ffb81f0e4..73ddad797 100644 --- a/src/llvm_backend.hpp +++ b/src/llvm_backend.hpp @@ -333,6 +333,10 @@ lbValue lb_emit_array_ep(lbProcedure *p, lbValue s, lbValue index); lbValue lb_emit_deep_field_gep(lbProcedure *p, lbValue e, Selection sel); lbValue lb_emit_deep_field_ev(lbProcedure *p, lbValue e, Selection sel); +lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column); +lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column); + + lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type); lbValue lb_emit_byte_swap(lbProcedure *p, lbValue value, Type *end_type); void lb_emit_defer_stmts(lbProcedure *p, lbDeferExitKind kind, lbBlock *block); diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 68050e0ce..4cfcecdc3 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -512,6 +512,34 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc res.value = llvm_const_array(lb_type(m, elem), elems, cast(unsigned)count); return res; + } else if (is_type_matrix(type) && + value.kind != ExactValue_Invalid && + value.kind != ExactValue_Compound) { + i64 row = type->Matrix.row_count; + i64 column = type->Matrix.column_count; + GB_ASSERT(row == column); + + Type *elem = type->Matrix.elem; + + lbValue single_elem = lb_const_value(m, elem, value, allow_local); + single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem)); + + i64 stride_bytes = matrix_type_stride(type); + i64 stride_elems = stride_bytes/type_size_of(elem); + + i64 total_elem_count = matrix_type_total_elems(type); + LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); + for (i64 i = 0; i < row; i++) { + elems[i*stride_elems + i] = single_elem.value; + } + for (i64 i = 0; i < total_elem_count; i++) { + if (elems[i] == nullptr) { + elems[i] = LLVMConstNull(lb_type(m, elem)); + } + } + + res.value = LLVMConstArray(lb_type(m, elem), elems, cast(unsigned)total_elem_count); + return res; } switch (value.kind) { diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 3056952f6..6b7d90ec0 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -477,10 +477,72 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r } +lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { + GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); + + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + if (op == Token_Mul) { + if (xt->kind == Type_Matrix) { + if (yt->kind == Type_Matrix) { + GB_ASSERT(is_type_matrix(type)); + GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); + GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); + + Type *elem = xt->Matrix.elem; + + lbAddr res = lb_add_local_generated(p, type, true); + for (i64 i = 0; i < xt->Matrix.row_count; i++) { + for (i64 j = 0; j < yt->Matrix.column_count; j++) { + for (i64 k = 0; k < xt->Matrix.column_count; k++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + + lbValue a = lb_emit_matrix_ev(p, lhs, i, k); + lbValue b = lb_emit_matrix_ev(p, rhs, k, j); + lbValue c = lb_emit_arith(p, op, a, b, elem); + lbValue d = lb_emit_load(p, dst); + lbValue e = lb_emit_arith(p, Token_Add, d, c, elem); + lb_emit_store(p, dst, e); + + } + } + } + + return lb_addr_load(p, res); + } + } + + } else { + GB_ASSERT(are_types_identical(xt, yt)); + GB_ASSERT(xt->kind == Type_Matrix); + // element-wise arithmetic + // pretend it is an array + lbValue array_lhs = lhs; + lbValue array_rhs = rhs; + Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt)); + GB_ASSERT(type_size_of(array_type) == type_size_of(type)); + + array_lhs.type = array_type; + array_rhs.type = array_type; + + lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, type); + array.type = type; + return array; + } + + GB_PANIC("TODO: lb_emit_arith_matrix"); + + return {}; +} + + lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) { return lb_emit_arith_array(p, op, lhs, rhs, type); + } else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) { + return lb_emit_arith_matrix(p, op, lhs, rhs, type); } else if (is_type_complex(type)) { lhs = lb_emit_conv(p, lhs, type); rhs = lb_emit_conv(p, rhs, type); @@ -1417,6 +1479,22 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { } return lb_addr_load(p, v); } + + if (is_type_matrix(dst) && !is_type_matrix(src)) { + GB_ASSERT(dst->Matrix.row_count == dst->Matrix.column_count); + + Type *elem = base_array_type(dst); + lbValue e = lb_emit_conv(p, value, elem); + lbAddr v = lb_add_local_generated(p, t, false); + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + isize j = cast(isize)i; + lbValue ptr = lb_emit_matrix_epi(p, v.addr, j, j); + lb_emit_store(p, ptr, e); + } + + + return lb_addr_load(p, v); + } if (is_type_any(dst)) { if (is_type_untyped_nil(src)) { diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index 0531c62bb..1b41be2a3 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1221,6 +1221,41 @@ lbValue lb_emit_ptr_offset(lbProcedure *p, lbValue ptr, lbValue index) { return res; } +lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) { + Type *t = s.type; + GB_ASSERT(is_type_pointer(t)); + Type *st = base_type(type_deref(t)); + GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st)); + + Type *ptr = base_array_type(st); + + isize index = row*column; + GB_ASSERT(0 <= index); + + LLVMValueRef indices[2] = { + LLVMConstInt(lb_type(p->module, t_int), 0, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)index, false), + }; + + lbValue res = {}; + if (lb_is_const(s)) { + res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices)); + } else { + res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), ""); + } + res.type = alloc_type_pointer(ptr); + return res; +} + +lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) { + Type *st = base_type(s.type); + GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st)); + + lbValue value = lb_address_from_load_or_generate_local(p, s); + lbValue ptr = lb_emit_matrix_epi(p, value, row, column); + return lb_emit_load(p, ptr); +} + void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len) { Type *t = lb_addr_type(slice); diff --git a/src/types.cpp b/src/types.cpp index 0313ade60..fd9b20c91 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1257,6 +1257,22 @@ i64 matrix_type_stride(Type *t) { return stride; } +i64 matrix_type_stride_in_elems(Type *t) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + i64 stride = matrix_type_stride(t); + return stride/gb_max(1, type_size_of(t->Matrix.elem)); +} + + +i64 matrix_type_total_elems(Type *t) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + i64 size = type_size_of(t); + i64 elem_size = type_size_of(t->Matrix.elem); + return size/gb_max(elem_size, 1); +} + bool is_type_dynamic_array(Type *t) { t = base_type(t); return t->kind == Type_DynamicArray; @@ -3174,17 +3190,17 @@ i64 type_align_of_internal(Type *t, TypePath *path) { case Type_Matrix: { Type *elem = t->Matrix.elem; - i64 row_count = t->Matrix.row_count; - // i64 column_count = t->Matrix.column_count; + i64 row_count = gb_max(t->Matrix.row_count, 1); + bool pop = type_path_push(path, elem); if (path->failure) { return FAILURE_ALIGNMENT; } + // elem align is used here rather than size as it make a little more sense i64 elem_align = type_align_of_internal(elem, path); if (pop) type_path_pop(path); - i64 align = gb_clamp(elem_align * row_count, elem_align, build_context.max_align); - + i64 align = gb_min(next_pow2(elem_align * row_count), build_context.max_align); return align; } @@ -3935,6 +3951,13 @@ gbString write_type_to_string(gbString str, Type *type) { str = gb_string_append_fmt(str, ") "); str = write_type_to_string(str, type->RelativeSlice.slice_type); break; + + case Type_Matrix: + str = gb_string_appendc(str, gb_bprintf("[%d", cast(int)type->Matrix.row_count)); + str = gb_string_appendc(str, "; "); + str = gb_string_appendc(str, gb_bprintf("%d]", cast(int)type->Matrix.column_count)); + str = write_type_to_string(str, type->Matrix.elem); + break; } return str; -- cgit v1.2.3 From 662cbaf425a54127dea206c3a35d776853bac169 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Tue, 19 Oct 2021 12:13:19 +0100 Subject: Support indexing matrices --- core/runtime/error_checks.odin | 23 +++++++++++++++ src/check_expr.cpp | 66 ++++++++++++++++++++++++++++++++++++++++-- src/checker.cpp | 1 + src/llvm_backend.hpp | 1 + src/llvm_backend_expr.cpp | 54 +++++++++++++++++++++++++++++++++- src/llvm_backend_general.cpp | 30 +++++++++++++++++++ src/llvm_backend_utility.cpp | 31 ++++++++++++++++++++ src/types.cpp | 4 +++ 8 files changed, 206 insertions(+), 4 deletions(-) (limited to 'src/checker.cpp') diff --git a/core/runtime/error_checks.odin b/core/runtime/error_checks.odin index bdd010b50..7f1aeb2d7 100644 --- a/core/runtime/error_checks.odin +++ b/core/runtime/error_checks.odin @@ -96,6 +96,29 @@ dynamic_array_expr_error :: proc "contextless" (file: string, line, column: i32, } +matrix_bounds_check_error :: proc "contextless" (file: string, line, column: i32, row_index, column_index, row_count, column_count: int) { + if 0 <= row_index && row_index < row_count && + 0 <= column_index && column_index < column_count { + return + } + handle_error :: proc "contextless" (file: string, line, column: i32, row_index, column_index, row_count, column_count: int) { + print_caller_location(Source_Code_Location{file, line, column, ""}) + print_string(" Matrix indices [") + print_i64(i64(row_index)) + print_string(", ") + print_i64(i64(column_index)) + print_string(" is out of bounds range [0..<") + print_i64(i64(row_count)) + print_string(", 0..<") + print_i64(i64(column_count)) + print_string("]") + print_byte('\n') + bounds_trap() + } + handle_error(file, line, column, row_index, column_index, row_count, column_count) +} + + type_assertion_check :: proc "contextless" (ok: bool, file: string, line, column: i32, from, to: typeid) { if ok { return diff --git a/src/check_expr.cpp b/src/check_expr.cpp index a75334e6c..73e1a7e51 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -6367,8 +6367,7 @@ bool check_set_index_data(Operand *o, Type *t, bool indirection, i64 *max_count, *max_count = t->Matrix.column_count; if (indirection) { o->mode = Addressing_Variable; - } else if (o->mode != Addressing_Variable && - o->mode != Addressing_Constant) { + } else if (o->mode != Addressing_Variable) { o->mode = Addressing_Value; } o->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count); @@ -6672,7 +6671,68 @@ void check_promote_optional_ok(CheckerContext *c, Operand *x, Type **val_type_, void check_matrix_index_expr(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) { - error(node, "TODO: matrix index expressions"); + ast_node(ie, MatrixIndexExpr, node); + + check_expr(c, o, ie->expr); + node->viral_state_flags |= ie->expr->viral_state_flags; + if (o->mode == Addressing_Invalid) { + o->expr = node; + return; + } + + Type *t = base_type(type_deref(o->type)); + bool is_ptr = is_type_pointer(o->type); + bool is_const = o->mode == Addressing_Constant; + + if (t->kind != Type_Matrix) { + gbString str = expr_to_string(o->expr); + gbString type_str = type_to_string(o->type); + defer (gb_string_free(str)); + defer (gb_string_free(type_str)); + if (is_const) { + error(o->expr, "Cannot use matrix indexing on constant '%s' of type '%s'", str, type_str); + } else { + error(o->expr, "Cannot use matrix indexing on '%s' of type '%s'", str, type_str); + } + o->mode = Addressing_Invalid; + o->expr = node; + return; + } + o->type = t->Matrix.elem; + if (is_ptr) { + o->mode = Addressing_Variable; + } else if (o->mode != Addressing_Variable) { + o->mode = Addressing_Value; + } + + if (ie->row_index == nullptr) { + gbString str = expr_to_string(o->expr); + error(o->expr, "Missing row index for '%s'", str); + gb_string_free(str); + o->mode = Addressing_Invalid; + o->expr = node; + return; + } + if (ie->column_index == nullptr) { + gbString str = expr_to_string(o->expr); + error(o->expr, "Missing column index for '%s'", str); + gb_string_free(str); + o->mode = Addressing_Invalid; + o->expr = node; + return; + } + + i64 row_count = t->Matrix.row_count; + i64 column_count = t->Matrix.column_count; + + i64 row_index = 0; + i64 column_index = 0; + bool row_ok = check_index_value(c, t, false, ie->row_index, row_count, &row_index, nullptr); + bool column_ok = check_index_value(c, t, false, ie->column_index, column_count, &column_index, nullptr); + + + gb_unused(row_ok); + gb_unused(column_ok); } diff --git a/src/checker.cpp b/src/checker.cpp index c0e6d47c0..23597167b 100644 --- a/src/checker.cpp +++ b/src/checker.cpp @@ -2022,6 +2022,7 @@ void generate_minimum_dependency_set(Checker *c, Entity *start) { String bounds_check_entities[] = { // Bounds checking related procedures str_lit("bounds_check_error"), + str_lit("matrix_bounds_check_error"), str_lit("slice_expr_error_hi"), str_lit("slice_expr_error_lo_hi"), str_lit("multi_pointer_slice_expr_error"), diff --git a/src/llvm_backend.hpp b/src/llvm_backend.hpp index 73ddad797..9041e7621 100644 --- a/src/llvm_backend.hpp +++ b/src/llvm_backend.hpp @@ -333,6 +333,7 @@ lbValue lb_emit_array_ep(lbProcedure *p, lbValue s, lbValue index); lbValue lb_emit_deep_field_gep(lbProcedure *p, lbValue e, Selection sel); lbValue lb_emit_deep_field_ev(lbProcedure *p, lbValue e, Selection sel); +lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lbValue column); lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column); lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column); diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index ed98c6845..bcbb77355 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -1727,7 +1727,7 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { } if (is_type_matrix(dst) && !is_type_matrix(src)) { - GB_ASSERT(dst->Matrix.row_count == dst->Matrix.column_count); + GB_ASSERT_MSG(dst->Matrix.row_count == dst->Matrix.column_count, "%s <- %s", type_to_string(dst), type_to_string(src)); Type *elem = base_array_type(dst); lbValue e = lb_emit_conv(p, value, elem); @@ -2805,6 +2805,10 @@ lbValue lb_build_expr(lbProcedure *p, Ast *expr) { case_ast_node(ie, IndexExpr, expr); return lb_addr_load(p, lb_build_addr(p, expr)); case_end; + + case_ast_node(ie, MatrixIndexExpr, expr); + return lb_addr_load(p, lb_build_addr(p, expr)); + case_end; case_ast_node(ia, InlineAsmExpr, expr); Type *t = type_of_expr(expr); @@ -3304,6 +3308,25 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { lbValue v = lb_emit_ptr_offset(p, elem, index); return lb_addr(v); } + + case Type_Matrix: { + lbValue matrix = {}; + matrix = lb_build_addr_ptr(p, ie->expr); + if (deref) { + matrix = lb_emit_load(p, matrix); + } + lbValue index = lb_build_expr(p, ie->index); + index = lb_emit_conv(p, index, t_int); + lbValue elem = lb_emit_matrix_ep(p, matrix, lb_const_int(p->module, t_int, 0), index); + elem = lb_emit_conv(p, elem, alloc_type_pointer(type_of_expr(expr))); + + auto index_tv = type_and_value_of_expr(ie->index); + if (index_tv.mode != Addressing_Constant) { + lbValue len = lb_const_int(p->module, t_int, t->Matrix.column_count); + lb_emit_bounds_check(p, ast_token(ie->index), index, len); + } + return lb_addr(elem); + } case Type_Basic: { // Basic_string @@ -3326,6 +3349,35 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { } } case_end; + + case_ast_node(ie, MatrixIndexExpr, expr); + Type *t = base_type(type_of_expr(ie->expr)); + + bool deref = is_type_pointer(t); + t = base_type(type_deref(t)); + + lbValue m = {}; + m = lb_build_addr_ptr(p, ie->expr); + if (deref) { + m = lb_emit_load(p, m); + } + lbValue row_index = lb_build_expr(p, ie->row_index); + lbValue column_index = lb_build_expr(p, ie->column_index); + row_index = lb_emit_conv(p, row_index, t_int); + column_index = lb_emit_conv(p, column_index, t_int); + lbValue elem = lb_emit_matrix_ep(p, m, row_index, column_index); + + auto row_index_tv = type_and_value_of_expr(ie->row_index); + auto column_index_tv = type_and_value_of_expr(ie->column_index); + if (row_index_tv.mode != Addressing_Constant || column_index_tv.mode != Addressing_Constant) { + lbValue row_count = lb_const_int(p->module, t_int, t->Matrix.row_count); + lbValue column_count = lb_const_int(p->module, t_int, t->Matrix.column_count); + lb_emit_matrix_bounds_check(p, ast_token(ie->row_index), row_index, column_index, row_count, column_count); + } + return lb_addr(elem); + + + case_end; case_ast_node(se, SliceExpr, expr); diff --git a/src/llvm_backend_general.cpp b/src/llvm_backend_general.cpp index 63a63349a..01221cad6 100644 --- a/src/llvm_backend_general.cpp +++ b/src/llvm_backend_general.cpp @@ -419,6 +419,36 @@ void lb_emit_bounds_check(lbProcedure *p, Token token, lbValue index, lbValue le lb_emit_runtime_call(p, "bounds_check_error", args); } +void lb_emit_matrix_bounds_check(lbProcedure *p, Token token, lbValue row_index, lbValue column_index, lbValue row_count, lbValue column_count) { + if (build_context.no_bounds_check) { + return; + } + if ((p->state_flags & StateFlag_no_bounds_check) != 0) { + return; + } + + row_index = lb_emit_conv(p, row_index, t_int); + column_index = lb_emit_conv(p, column_index, t_int); + row_count = lb_emit_conv(p, row_count, t_int); + column_count = lb_emit_conv(p, column_count, t_int); + + lbValue file = lb_find_or_add_entity_string(p->module, get_file_path_string(token.pos.file_id)); + lbValue line = lb_const_int(p->module, t_i32, token.pos.line); + lbValue column = lb_const_int(p->module, t_i32, token.pos.column); + + auto args = array_make(permanent_allocator(), 7); + args[0] = file; + args[1] = line; + args[2] = column; + args[3] = row_index; + args[4] = column_index; + args[5] = row_count; + args[6] = column_count; + + lb_emit_runtime_call(p, "matrix_bounds_check_error", args); +} + + void lb_emit_multi_pointer_slice_bounds_check(lbProcedure *p, Token token, lbValue low, lbValue high) { if (build_context.no_bounds_check) { return; diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index 3971c0ca6..c7e9e1742 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1249,6 +1249,37 @@ lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) { return res; } +lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lbValue column) { + Type *t = s.type; + GB_ASSERT(is_type_pointer(t)); + Type *mt = base_type(type_deref(t)); + GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt)); + + Type *ptr = base_array_type(mt); + + LLVMValueRef stride_elems = lb_const_int(p->module, t_int, matrix_type_stride_in_elems(mt)).value; + + row = lb_emit_conv(p, row, t_int); + column = lb_emit_conv(p, column, t_int); + + LLVMValueRef index = LLVMBuildAdd(p->builder, row.value, LLVMBuildMul(p->builder, column.value, stride_elems, ""), ""); + + LLVMValueRef indices[2] = { + LLVMConstInt(lb_type(p->module, t_int), 0, false), + index, + }; + + lbValue res = {}; + if (lb_is_const(s)) { + res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices)); + } else { + res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), ""); + } + res.type = alloc_type_pointer(ptr); + return res; +} + + lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) { Type *st = base_type(s.type); GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st)); diff --git a/src/types.cpp b/src/types.cpp index 8e64a10c1..ec094b4ff 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1726,6 +1726,8 @@ bool is_type_indexable(Type *t) { return true; case Type_RelativeSlice: return true; + case Type_Matrix: + return true; } return false; } @@ -1743,6 +1745,8 @@ bool is_type_sliceable(Type *t) { return false; case Type_RelativeSlice: return true; + case Type_Matrix: + return false; } return false; } -- cgit v1.2.3