From ba331024af2f5074125442e91dda6c8e63324c8f Mon Sep 17 00:00:00 2001 From: gingerBill Date: Mon, 18 Oct 2021 18:16:52 +0100 Subject: Very basic matrix support in backend --- src/llvm_backend_const.cpp | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) (limited to 'src/llvm_backend_const.cpp') diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 68050e0ce..4cfcecdc3 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -512,6 +512,34 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc res.value = llvm_const_array(lb_type(m, elem), elems, cast(unsigned)count); return res; + } else if (is_type_matrix(type) && + value.kind != ExactValue_Invalid && + value.kind != ExactValue_Compound) { + i64 row = type->Matrix.row_count; + i64 column = type->Matrix.column_count; + GB_ASSERT(row == column); + + Type *elem = type->Matrix.elem; + + lbValue single_elem = lb_const_value(m, elem, value, allow_local); + single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem)); + + i64 stride_bytes = matrix_type_stride(type); + i64 stride_elems = stride_bytes/type_size_of(elem); + + i64 total_elem_count = matrix_type_total_elems(type); + LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); + for (i64 i = 0; i < row; i++) { + elems[i*stride_elems + i] = single_elem.value; + } + for (i64 i = 0; i < total_elem_count; i++) { + if (elems[i] == nullptr) { + elems[i] = LLVMConstNull(lb_type(m, elem)); + } + } + + res.value = LLVMConstArray(lb_type(m, elem), elems, cast(unsigned)total_elem_count); + return res; } switch (value.kind) { -- cgit v1.2.3 From 82b6772ea4fa9872a1fb98305814be8cf7f2c7c4 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 20 Oct 2021 00:40:03 +0100 Subject: Support matrix literals --- core/fmt/fmt.odin | 4 +- core/runtime/core.odin | 2 +- src/check_expr.cpp | 5 ++ src/llvm_backend.hpp | 2 + src/llvm_backend_const.cpp | 83 ++++++++++++++++++++++++++++-- src/llvm_backend_expr.cpp | 119 ++++++++++++++++++++++++++++++++++++++++--- src/llvm_backend_type.cpp | 2 +- src/llvm_backend_utility.cpp | 78 +++++++++++++++++++++++++--- src/parser.cpp | 1 + src/types.cpp | 33 ++++++++++++ 10 files changed, 306 insertions(+), 23 deletions(-) (limited to 'src/llvm_backend_const.cpp') diff --git a/core/fmt/fmt.odin b/core/fmt/fmt.odin index c0190a0b9..dc5b529ea 100644 --- a/core/fmt/fmt.odin +++ b/core/fmt/fmt.odin @@ -1967,7 +1967,7 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { for col in 0.. 0 { io.write_string(fi.writer, ", ") } - offset := row*info.elem_size + col*info.stride + offset := (row + col*info.elem_stride)*info.elem_size data := uintptr(v.data) + uintptr(offset) fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) @@ -1980,7 +1980,7 @@ fmt_value :: proc(fi: ^Info, v: any, verb: rune) { for col in 0.. 
0 { io.write_string(fi.writer, "; ") } - offset := row*info.elem_size + col*info.stride + offset := (row + col*info.elem_stride)*info.elem_size data := uintptr(v.data) + uintptr(offset) fmt_arg(fi, any{rawptr(data), info.elem.id}, verb) diff --git a/core/runtime/core.odin b/core/runtime/core.odin index 611b4002c..ba1e81da6 100644 --- a/core/runtime/core.odin +++ b/core/runtime/core.odin @@ -165,7 +165,7 @@ Type_Info_Relative_Slice :: struct { Type_Info_Matrix :: struct { elem: ^Type_Info, elem_size: int, - stride: int, // bytes + elem_stride: int, row_count: int, column_count: int, } diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 73e1a7e51..eb6040320 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -7369,6 +7369,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type case Type_Array: case Type_DynamicArray: case Type_SimdVector: + case Type_Matrix: { Type *elem_type = nullptr; String context_name = {}; @@ -7395,6 +7396,10 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type elem_type = t->SimdVector.elem; context_name = str_lit("simd vector literal"); max_type_count = t->SimdVector.count; + } else if (t->kind == Type_Matrix) { + elem_type = t->Matrix.elem; + context_name = str_lit("matrix literal"); + max_type_count = t->Matrix.row_count*t->Matrix.column_count; } else { GB_PANIC("unreachable"); } diff --git a/src/llvm_backend.hpp b/src/llvm_backend.hpp index 9041e7621..d2abed354 100644 --- a/src/llvm_backend.hpp +++ b/src/llvm_backend.hpp @@ -393,6 +393,8 @@ lbValue lb_soa_struct_len(lbProcedure *p, lbValue value); void lb_emit_increment(lbProcedure *p, lbValue addr); lbValue lb_emit_select(lbProcedure *p, lbValue cond, lbValue x, lbValue y); +lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t); + void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len); lbValue lb_type_info(lbModule *m, Type *type); diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 4cfcecdc3..413fb365b 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -523,14 +523,11 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc lbValue single_elem = lb_const_value(m, elem, value, allow_local); single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem)); - - i64 stride_bytes = matrix_type_stride(type); - i64 stride_elems = stride_bytes/type_size_of(elem); - + i64 total_elem_count = matrix_type_total_elems(type); LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); for (i64 i = 0; i < row; i++) { - elems[i*stride_elems + i] = single_elem.value; + elems[matrix_index_to_offset(type, i)] = single_elem.value; } for (i64 i = 0; i < total_elem_count; i++) { if (elems[i] == nullptr) { @@ -984,6 +981,82 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc res.value = LLVMConstInt(lb_type(m, original_type), bits, false); return res; + } else if (is_type_matrix(type)) { + ast_node(cl, CompoundLit, value.value_compound); + Type *elem_type = type->Matrix.elem; + isize elem_count = cl->elems.count; + if (elem_count == 0 || !elem_type_can_be_constant(elem_type)) { + return lb_const_nil(m, original_type); + } + + i64 max_count = type->Matrix.row_count*type->Matrix.column_count; + i64 total_count = matrix_type_total_elems(type); + + LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count); + 
if (cl->elems[0]->kind == Ast_FieldValue) { + for_array(j, cl->elems) { + Ast *elem = cl->elems[j]; + ast_node(fv, FieldValue, elem); + if (is_ast_range(fv->field)) { + ast_node(ie, BinaryExpr, fv->field); + TypeAndValue lo_tav = ie->left->tav; + TypeAndValue hi_tav = ie->right->tav; + GB_ASSERT(lo_tav.mode == Addressing_Constant); + GB_ASSERT(hi_tav.mode == Addressing_Constant); + + TokenKind op = ie->op.kind; + i64 lo = exact_value_to_i64(lo_tav.value); + i64 hi = exact_value_to_i64(hi_tav.value); + if (op != Token_RangeHalf) { + hi += 1; + } + TypeAndValue tav = fv->value->tav; + LLVMValueRef val = lb_const_value(m, elem_type, tav.value, allow_local).value; + for (i64 k = lo; k < hi; k++) { + i64 offset = matrix_index_to_offset(type, k); + GB_ASSERT(values[offset] == nullptr); + values[offset] = val; + } + } else { + TypeAndValue index_tav = fv->field->tav; + GB_ASSERT(index_tav.mode == Addressing_Constant); + i64 index = exact_value_to_i64(index_tav.value); + TypeAndValue tav = fv->value->tav; + LLVMValueRef val = lb_const_value(m, elem_type, tav.value, allow_local).value; + i64 offset = matrix_index_to_offset(type, index); + GB_ASSERT(values[offset] == nullptr); + values[offset] = val; + } + } + + for (i64 i = 0; i < total_count; i++) { + if (values[i] == nullptr) { + values[i] = LLVMConstNull(lb_type(m, elem_type)); + } + } + + res.value = lb_build_constant_array_values(m, type, elem_type, cast(isize)total_count, values, allow_local); + return res; + } else { + GB_ASSERT_MSG(elem_count == max_count, "%td != %td", elem_count, max_count); + + LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count); + + for_array(i, cl->elems) { + TypeAndValue tav = cl->elems[i]->tav; + GB_ASSERT(tav.mode != Addressing_Invalid); + i64 offset = matrix_index_to_offset(type, i); + values[offset] = lb_const_value(m, elem_type, tav.value, allow_local).value; + } + for (isize i = 0; i < total_count; i++) { + if (values[i] == nullptr) { + values[i] = LLVMConstNull(lb_type(m, elem_type)); + } + } + + res.value = lb_build_constant_array_values(m, type, elem_type, cast(isize)total_count, values, allow_local); + return res; + } } else { return lb_const_nil(m, original_type); } diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index bcbb77355..518ce33af 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -648,18 +648,23 @@ slow_form: i64 inner = xt->Matrix.column_count; i64 outer_columns = yt->Matrix.column_count; + auto inners = slice_make(permanent_allocator(), inner); + for (i64 j = 0; j < outer_columns; j++) { for (i64 i = 0; i < outer_rows; i++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); for (i64 k = 0; k < inner; k++) { - lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); - lbValue d0 = lb_emit_load(p, dst); - - lbValue a = lb_emit_matrix_ev(p, lhs, i, k); - lbValue b = lb_emit_matrix_ev(p, rhs, k, j); - lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem); - lbValue d = lb_emit_arith(p, Token_Add, d0, c, elem); - lb_emit_store(p, dst, d); + inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k); + inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j); } + + lbValue sum = lb_emit_load(p, dst); + for (i64 k = 0; k < inner; k++) { + lbValue a = inners[k][0]; + lbValue b = inners[k][1]; + sum = lb_emit_mul_add(p, a, b, sum, elem); + } + lb_emit_store(p, dst, sum); } } @@ -3626,6 +3631,7 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { case Type_Slice: et = bt->Slice.elem; break; case Type_BitSet: et = bt->BitSet.elem; break; 
case Type_SimdVector: et = bt->SimdVector.elem; break; + case Type_Matrix: et = bt->Matrix.elem; break; } String proc_name = {}; @@ -4157,7 +4163,104 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { } break; } + + case Type_Matrix: { + if (cl->elems.count > 0) { + lb_addr_store(p, v, lb_const_value(p->module, type, exact_value_compound(expr))); + + auto temp_data = array_make(temporary_allocator(), 0, cl->elems.count); + // NOTE(bill): Separate value, gep, store into their own chunks + for_array(i, cl->elems) { + Ast *elem = cl->elems[i]; + + if (elem->kind == Ast_FieldValue) { + ast_node(fv, FieldValue, elem); + if (lb_is_elem_const(fv->value, et)) { + continue; + } + if (is_ast_range(fv->field)) { + ast_node(ie, BinaryExpr, fv->field); + TypeAndValue lo_tav = ie->left->tav; + TypeAndValue hi_tav = ie->right->tav; + GB_ASSERT(lo_tav.mode == Addressing_Constant); + GB_ASSERT(hi_tav.mode == Addressing_Constant); + + TokenKind op = ie->op.kind; + i64 lo = exact_value_to_i64(lo_tav.value); + i64 hi = exact_value_to_i64(hi_tav.value); + if (op != Token_RangeHalf) { + hi += 1; + } + + lbValue value = lb_build_expr(p, fv->value); + + for (i64 k = lo; k < hi; k++) { + lbCompoundLitElemTempData data = {}; + data.value = value; + + data.elem_index = cast(i32)matrix_index_to_offset(bt, k); + array_add(&temp_data, data); + } + + } else { + auto tav = fv->field->tav; + GB_ASSERT(tav.mode == Addressing_Constant); + i64 index = exact_value_to_i64(tav.value); + + lbValue value = lb_build_expr(p, fv->value); + lbCompoundLitElemTempData data = {}; + data.value = lb_emit_conv(p, value, et); + data.expr = fv->value; + + data.elem_index = cast(i32)matrix_index_to_offset(bt, index); + array_add(&temp_data, data); + } + + } else { + if (lb_is_elem_const(elem, et)) { + continue; + } + lbCompoundLitElemTempData data = {}; + data.expr = elem; + data.elem_index = cast(i32)matrix_index_to_offset(bt, i); + array_add(&temp_data, data); + } + } + + for_array(i, temp_data) { + temp_data[i].gep = lb_emit_array_epi(p, lb_addr_get_ptr(p, v), temp_data[i].elem_index); + } + + for_array(i, temp_data) { + lbValue field_expr = temp_data[i].value; + Ast *expr = temp_data[i].expr; + + auto prev_hint = lb_set_copy_elision_hint(p, lb_addr(temp_data[i].gep), expr); + + if (field_expr.value == nullptr) { + field_expr = lb_build_expr(p, expr); + } + Type *t = field_expr.type; + GB_ASSERT(t->kind != Type_Tuple); + lbValue ev = lb_emit_conv(p, field_expr, et); + + if (!p->copy_elision_hint.used) { + temp_data[i].value = ev; + } + + lb_reset_copy_elision_hint(p, prev_hint); + } + + for_array(i, temp_data) { + if (temp_data[i].value.value != nullptr) { + lb_emit_store(p, temp_data[i].gep, temp_data[i].value); + } + } + } + break; + } + } return v; diff --git a/src/llvm_backend_type.cpp b/src/llvm_backend_type.cpp index 82e20bf60..decb57702 100644 --- a/src/llvm_backend_type.cpp +++ b/src/llvm_backend_type.cpp @@ -877,7 +877,7 @@ void lb_setup_type_info_data(lbProcedure *p) { // NOTE(bill): Setup type_info da LLVMValueRef vals[5] = { lb_get_type_info_ptr(m, t->Matrix.elem).value, lb_const_int(m, t_int, ez).value, - lb_const_int(m, t_int, matrix_type_stride(t)).value, + lb_const_int(m, t_int, matrix_type_stride_in_elems(t)).value, lb_const_int(m, t_int, t->Matrix.row_count).value, lb_const_int(m, t_int, t->Matrix.column_count).value, }; diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index c7e9e1742..fb9264661 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1225,18 +1225,53 @@ 
lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) { Type *t = s.type; GB_ASSERT(is_type_pointer(t)); Type *mt = base_type(type_deref(t)); - GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt)); - + Type *ptr = base_array_type(mt); - i64 stride_elems = matrix_type_stride_in_elems(mt); + if (column == 0) { + GB_ASSERT_MSG(is_type_matrix(mt) || is_type_array_like(mt), "%s", type_to_string(mt)); + + LLVMValueRef indices[2] = { + LLVMConstInt(lb_type(p->module, t_int), 0, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)row, false), + }; + + lbValue res = {}; + if (lb_is_const(s)) { + res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices)); + } else { + res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), ""); + } + + Type *ptr = base_array_type(mt); + res.type = alloc_type_pointer(ptr); + return res; + } else if (row == 0 && is_type_array_like(mt)) { + LLVMValueRef indices[2] = { + LLVMConstInt(lb_type(p->module, t_int), 0, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)column, false), + }; + + lbValue res = {}; + if (lb_is_const(s)) { + res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices)); + } else { + res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), ""); + } + + Type *ptr = base_array_type(mt); + res.type = alloc_type_pointer(ptr); + return res; + } + - isize index = row + column*stride_elems; - GB_ASSERT(0 <= index); + GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt)); + + isize offset = matrix_indices_to_offset(mt, row, column); LLVMValueRef indices[2] = { LLVMConstInt(lb_type(p->module, t_int), 0, false), - LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)index, false), + LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)offset, false), }; lbValue res = {}; @@ -1447,3 +1482,34 @@ lbValue lb_soa_struct_cap(lbProcedure *p, lbValue value) { } return lb_emit_struct_ev(p, value, cast(i32)n); } + + + +lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t) { + lbModule *m = p->module; + + a = lb_emit_conv(p, a, t); + b = lb_emit_conv(p, b, t); + c = lb_emit_conv(p, c, t); + + if (!is_type_different_to_arch_endianness(t) && is_type_float(t)) { + char const *name = "llvm.fma"; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s", name); + + LLVMTypeRef types[1] = {}; + types[0] = lb_type(m, t); + + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); + LLVMValueRef values[3] = {}; + values[0] = a.value; + values[1] = b.value; + values[2] = c.value; + LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); + return {call, t}; + } else { + lbValue x = lb_emit_arith(p, Token_Mul, a, b, t); + lbValue y = lb_emit_arith(p, Token_Add, x, c, t); + return y; + } +} \ No newline at end of file diff --git a/src/parser.cpp b/src/parser.cpp index c29cf70d9..83da481d5 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -2569,6 +2569,7 @@ bool is_literal_type(Ast *node) { case Ast_DynamicArrayType: case Ast_MapType: case Ast_BitSetType: + case Ast_MatrixType: case Ast_CallExpr: return true; } diff --git a/src/types.cpp b/src/types.cpp index ec094b4ff..bbabdf732 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1276,6 +1276,39 @@ i64 matrix_type_total_elems(Type *t) { return size/gb_max(elem_size, 1); } +void matrix_indices_from_index(Type *t, i64 index, i64 *row_index_, i64 *column_index_) { + t = base_type(t); + 
GB_ASSERT(t->kind == Type_Matrix); + i64 row_count = t->Matrix.row_count; + i64 column_count = t->Matrix.column_count; + GB_ASSERT(0 <= index && index < row_count*column_count); + + i64 row_index = index / column_count; + i64 column_index = index % column_count; + + if (row_index_) *row_index_ = row_index; + if (column_index_) *column_index_ = column_index; +} + +i64 matrix_index_to_offset(Type *t, i64 index) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + + i64 row_index, column_index; + matrix_indices_from_index(t, index, &row_index, &column_index); + i64 stride_elems = matrix_type_stride_in_elems(t); + return stride_elems*column_index + row_index; +} + +i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) { + t = base_type(t); + GB_ASSERT(t->kind == Type_Matrix); + GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count); + GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count); + i64 stride_elems = matrix_type_stride_in_elems(t); + return stride_elems*column_index + row_index; +} + bool is_type_dynamic_array(Type *t) { t = base_type(t); return t->kind == Type_DynamicArray; -- cgit v1.2.3 From d67d7168e2d4ed8e0e5f0d1b23aba5e5ebac6847 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Thu, 21 Oct 2021 00:04:22 +0100 Subject: Allow scalars with matrices --- src/check_expr.cpp | 8 ++++++++ src/llvm_backend_const.cpp | 2 +- src/llvm_backend_expr.cpp | 32 ++++++++++++++++++++++++-------- 3 files changed, 33 insertions(+), 9 deletions(-) (limited to 'src/llvm_backend_const.cpp') diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 8a1e5fd86..498bf78c7 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -657,6 +657,14 @@ i64 check_distance_between_types(CheckerContext *c, Operand *operand, Type *type return distance + 6; } } + + if (is_type_matrix(dst)) { + Type *elem = base_array_type(dst); + i64 distance = check_distance_between_types(c, operand, elem); + if (distance >= 0) { + return distance + 7; + } + } if (is_type_any(dst)) { if (!is_type_polymorphic(src)) { diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 413fb365b..554255f47 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -527,7 +527,7 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc i64 total_elem_count = matrix_type_total_elems(type); LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); for (i64 i = 0; i < row; i++) { - elems[matrix_index_to_offset(type, i)] = single_elem.value; + elems[matrix_indices_to_offset(type, i, i)] = single_elem.value; } for (i64 i = 0; i < total_elem_count; i++) { if (elems[i] == nullptr) { diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index beb860383..cdc1deea1 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -331,7 +331,7 @@ bool lb_try_direct_vector_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbVal z = LLVMBuildFRem(p->builder, x, y, ""); break; default: - GB_PANIC("Unsupported vector operation"); + GB_PANIC("Unsupported vector operation %.*s", LIT(token_strings[op])); break; } @@ -918,10 +918,11 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false) { GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); - Type *xt = base_type(lhs.type); - Type *yt = base_type(rhs.type); if (op == Token_Mul && 
!component_wise) { + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + if (xt->kind == Type_Matrix) { if (yt->kind == Type_Matrix) { return lb_emit_matrix_mul(p, lhs, rhs, type); @@ -934,21 +935,36 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue } } else { - GB_ASSERT(are_types_identical(xt, yt)); + if (is_type_matrix(lhs.type)) { + rhs = lb_emit_conv(p, rhs, lhs.type); + } else { + lhs = lb_emit_conv(p, lhs, rhs.type); + } + + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + GB_ASSERT_MSG(are_types_identical(xt, yt), "%s %.*s %s", type_to_string(lhs.type), LIT(token_strings[op]), type_to_string(rhs.type)); GB_ASSERT(xt->kind == Type_Matrix); // element-wise arithmetic // pretend it is an array lbValue array_lhs = lhs; lbValue array_rhs = rhs; Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt)); - GB_ASSERT(type_size_of(array_type) == type_size_of(type)); + GB_ASSERT(type_size_of(array_type) == type_size_of(xt)); array_lhs.type = array_type; array_rhs.type = array_type; - lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, array_type); - array.type = type; - return array; + if (token_is_comparison(op)) { + lbValue res = lb_emit_comp(p, op, array_lhs, array_rhs); + return lb_emit_conv(p, res, type); + } else { + lbValue array = lb_emit_arith(p, op, array_lhs, array_rhs, array_type); + array.type = type; + return array; + } + } GB_PANIC("TODO: lb_emit_arith_matrix"); -- cgit v1.2.3 From 48d277a3c4604481074df2914efbaba9e0dbed25 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Thu, 21 Oct 2021 01:34:39 +0100 Subject: Allow conversions between matrices of the same element count --- src/check_expr.cpp | 4 +++- src/llvm_backend_const.cpp | 4 ++-- src/llvm_backend_expr.cpp | 38 ++++++++++++++++++++++++++++---------- src/types.cpp | 34 +++++++++++++--------------------- 4 files changed, 46 insertions(+), 34 deletions(-) (limited to 'src/llvm_backend_const.cpp') diff --git a/src/check_expr.cpp b/src/check_expr.cpp index ad12e00c8..ee7493553 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -2469,7 +2469,9 @@ bool check_is_castable_to(CheckerContext *c, Operand *operand, Type *y) { } if (src->Matrix.row_count != src->Matrix.column_count) { - return false; + i64 src_count = src->Matrix.row_count*src->Matrix.column_count; + i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count; + return src_count == dst_count; } if (dst->Matrix.row_count != dst->Matrix.column_count) { diff --git a/src/llvm_backend_const.cpp b/src/llvm_backend_const.cpp index 554255f47..b543089e5 100644 --- a/src/llvm_backend_const.cpp +++ b/src/llvm_backend_const.cpp @@ -524,7 +524,7 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc lbValue single_elem = lb_const_value(m, elem, value, allow_local); single_elem.value = llvm_const_cast(single_elem.value, lb_type(m, elem)); - i64 total_elem_count = matrix_type_total_elems(type); + i64 total_elem_count = matrix_type_total_internal_elems(type); LLVMValueRef *elems = gb_alloc_array(permanent_allocator(), LLVMValueRef, cast(isize)total_elem_count); for (i64 i = 0; i < row; i++) { elems[matrix_indices_to_offset(type, i, i)] = single_elem.value; @@ -990,7 +990,7 @@ lbValue lb_const_value(lbModule *m, Type *type, ExactValue value, bool allow_loc } i64 max_count = type->Matrix.row_count*type->Matrix.column_count; - i64 total_count = matrix_type_total_elems(type); + i64 total_count = matrix_type_total_internal_elems(type); 
LLVMValueRef *values = gb_alloc_array(temporary_allocator(), LLVMValueRef, cast(isize)total_count); if (cl->elems[0]->kind == Ast_FieldValue) { diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 9582be93c..eb88bbde0 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -508,7 +508,7 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { GB_ASSERT(mt->kind == Type_Matrix); LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem); - unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); + unsigned total_count = cast(unsigned)matrix_type_total_internal_elems(mt); LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value; @@ -948,7 +948,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue // pretend it is an array lbValue array_lhs = lhs; lbValue array_rhs = rhs; - Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt)); + Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_internal_elems(xt)); GB_ASSERT(type_size_of(array_type) == type_size_of(xt)); array_lhs.type = array_type; @@ -1941,15 +1941,33 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { GB_ASSERT(dst->kind == Type_Matrix); GB_ASSERT(src->kind == Type_Matrix); lbAddr v = lb_add_local_generated(p, t, true); - for (i64 j = 0; j < dst->Matrix.column_count; j++) { - for (i64 i = 0; i < dst->Matrix.row_count; i++) { - if (i < src->Matrix.row_count && j < src->Matrix.column_count) { - lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + + if (is_matrix_square(dst) && is_matrix_square(dst)) { + for (i64 j = 0; j < dst->Matrix.column_count; j++) { + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + if (i < src->Matrix.row_count && j < src->Matrix.column_count) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_emit_matrix_ev(p, value, i, j); + lb_emit_store(p, d, s); + } else if (i == j) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true); + lb_emit_store(p, d, s); + } + } + } + } else { + i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count; + i64 src_count = src->Matrix.row_count*src->Matrix.column_count; + GB_ASSERT(dst_count == src_count); + + for (i64 j = 0; j < src->Matrix.column_count; j++) { + for (i64 i = 0; i < src->Matrix.row_count; i++) { lbValue s = lb_emit_matrix_ev(p, value, i, j); - lb_emit_store(p, d, s); - } else if (i == j) { - lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); - lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true); + i64 index = i + j*src->Matrix.row_count; + i64 dst_i = index%dst->Matrix.row_count; + i64 dst_j = index/dst->Matrix.row_count; + lbValue d = lb_emit_matrix_epi(p, v.addr, dst_i, dst_j); lb_emit_store(p, d, s); } } diff --git a/src/types.cpp b/src/types.cpp index d3fa363c2..3abcebdfb 100644 --- a/src/types.cpp +++ b/src/types.cpp @@ -1293,7 +1293,7 @@ i64 matrix_type_stride_in_elems(Type *t) { } -i64 matrix_type_total_elems(Type *t) { +i64 matrix_type_total_internal_elems(Type *t) { t = base_type(t); GB_ASSERT(t->kind == Type_Matrix); i64 size = type_size_of(t); @@ -1301,37 +1301,29 @@ i64 matrix_type_total_elems(Type *t) { return size/gb_max(elem_size, 1); } -void matrix_indices_from_index(Type *t, i64 index, i64 *row_index_, i64 *column_index_) { +i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 
column_index) { t = base_type(t); GB_ASSERT(t->kind == Type_Matrix); - i64 row_count = t->Matrix.row_count; - i64 column_count = t->Matrix.column_count; - GB_ASSERT(0 <= index && index < row_count*column_count); - - i64 row_index = index / column_count; - i64 column_index = index % column_count; - - if (row_index_) *row_index_ = row_index; - if (column_index_) *column_index_ = column_index; + GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count); + GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count); + i64 stride_elems = matrix_type_stride_in_elems(t); + return stride_elems*column_index + row_index; } - i64 matrix_index_to_offset(Type *t, i64 index) { t = base_type(t); GB_ASSERT(t->kind == Type_Matrix); - i64 row_index, column_index; - matrix_indices_from_index(t, index, &row_index, &column_index); - i64 stride_elems = matrix_type_stride_in_elems(t); - return stride_elems*column_index + row_index; + i64 row_index = index%t->Matrix.row_count; + i64 column_index = index/t->Matrix.row_count; + return matrix_indices_to_offset(t, row_index, column_index); } -i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) { + + +bool is_matrix_square(Type *t) { t = base_type(t); GB_ASSERT(t->kind == Type_Matrix); - GB_ASSERT(0 <= row_index && row_index < t->Matrix.row_count); - GB_ASSERT(0 <= column_index && column_index < t->Matrix.column_count); - i64 stride_elems = matrix_type_stride_in_elems(t); - return stride_elems*column_index + row_index; + return t->Matrix.row_count == t->Matrix.column_count; } bool is_type_valid_for_matrix_elems(Type *t) { -- cgit v1.2.3
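
Usage note (not part of the patches above): a minimal Odin sketch of what these four commits enable, written as an assumption from the commit subjects and the backend paths they add — a scalar constant filling the diagonal, matrix compound literals, mixed scalar/matrix arithmetic, and explicit conversion between matrix types with the same total element count. The matrix[R, C]T syntax and core:fmt printing are taken from the Odin files touched in the series; literal element ordering and scalar semantics follow matrix_index_to_offset and lb_emit_arith_matrix as defined here, so verify the details against the matching Odin release before relying on them.

package main

import "core:fmt"

main :: proc() {
	// First commit: a non-compound constant used as a matrix value fills the
	// diagonal (the const path stores single_elem at offsets (i, i)).
	identity: matrix[3, 3]f32 = 1

	// Second commit ("Support matrix literals"): compound literal whose element
	// indices are mapped to padded column-major storage via matrix_index_to_offset.
	m := matrix[2, 2]f32{
		1, 2,
		3, 4,
	}

	// Third commit ("Allow scalars with matrices"): the scalar operand is first
	// converted to the matrix type (lb_emit_conv), then the element-wise
	// arithmetic path in lb_emit_arith_matrix runs.
	n := m + 2

	// Fourth commit: explicit conversion between matrix types with the same
	// total element count (2*2 == 1*4 here), per check_is_castable_to.
	flat := cast(matrix[1, 4]f32)m

	// The fmt.odin change in the first commit prints matrices element by element
	// using elem_stride from Type_Info_Matrix.
	fmt.println(identity)
	fmt.println(m)
	fmt.println(n)
	fmt.println(flat)
}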