diff options
| author | gingerBill <gingerBill@users.noreply.github.com> | 2021-10-26 21:08:08 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2021-10-26 21:08:08 +0100 |
| commit | c4d2aae0ed55d972b0074031ac82db6f9546447e (patch) | |
| tree | c23fe528ddaee43ea2c9ecfd5b95d93ef7fba467 /src/llvm_backend_expr.cpp | |
| parent | c722665c3239019fe9f90d247726cc42c921e1db (diff) | |
| parent | 549a383cf06ad45edd634e67c27a1246323a9d8c (diff) | |
Merge pull request #1245 from odin-lang/new-matrix-type
`matrix` type
Diffstat (limited to 'src/llvm_backend_expr.cpp')
| -rw-r--r-- | src/llvm_backend_expr.cpp | 749 |
1 files changed, 748 insertions, 1 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 6737c97bc..c9827ae3a 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -331,7 +331,7 @@ bool lb_try_direct_vector_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbVal z = LLVMBuildFRem(p->builder, x, y, ""); break; default: - GB_PANIC("Unsupported vector operation"); + GB_PANIC("Unsupported vector operation %.*s", LIT(token_strings[op])); break; } @@ -476,11 +476,545 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r } } +bool lb_is_matrix_simdable(Type *t) { + Type *mt = base_type(t); + GB_ASSERT(mt->kind == Type_Matrix); + + Type *elem = core_type(mt->Matrix.elem); + if (is_type_complex(elem)) { + return false; + } + + if (is_type_different_to_arch_endianness(elem)) { + return false; + } + + switch (build_context.metrics.arch) { + case TargetArch_amd64: + case TargetArch_arm64: + // possible + break; + case TargetArch_386: + case TargetArch_wasm32: + // nope + return false; + } + + if (elem->kind == Type_Basic) { + switch (elem->Basic.kind) { + case Basic_f16: + case Basic_f16le: + case Basic_f16be: + switch (build_context.metrics.arch) { + case TargetArch_amd64: + return false; + case TargetArch_arm64: + // TODO(bill): determine when this is fine + return true; + case TargetArch_386: + case TargetArch_wasm32: + return false; + } + } + } + + return true; +} + + +LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { + Type *mt = base_type(matrix.type); + GB_ASSERT(mt->kind == Type_Matrix); + LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem); + + unsigned total_count = cast(unsigned)matrix_type_total_internal_elems(mt); + LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); + +#if 1 + LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value; + LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), ""); + LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); + LLVMSetAlignment(matrix_vector, cast(unsigned)type_align_of(mt)); + return matrix_vector; +#else + LLVMValueRef matrix_vector = LLVMBuildBitCast(p->builder, matrix.value, total_matrix_type, ""); + return matrix_vector; +#endif +} + +LLVMValueRef lb_matrix_trimmed_vector_mask(lbProcedure *p, Type *mt) { + mt = base_type(mt); + GB_ASSERT(mt->kind == Type_Matrix); + + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + unsigned row_count = cast(unsigned)mt->Matrix.row_count; + unsigned column_count = cast(unsigned)mt->Matrix.column_count; + unsigned mask_elems_index = 0; + auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count*column_count); + for (unsigned j = 0; j < column_count; j++) { + for (unsigned i = 0; i < row_count; i++) { + unsigned offset = stride*j + i; + mask_elems[mask_elems_index++] = lb_const_int(p->module, t_u32, offset).value; + } + } + + LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count); + return mask; +} + +LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) { + LLVMValueRef vector = lb_matrix_to_vector(p, m); + + Type *mt = base_type(m.type); + GB_ASSERT(mt->kind == Type_Matrix); + + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + unsigned row_count = cast(unsigned)mt->Matrix.row_count; + if (stride == row_count) { + return vector; + } + + LLVMValueRef mask = lb_matrix_trimmed_vector_mask(p, mt); + LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, ""); + return trimmed_vector; +} + + +lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { + if (is_type_array(m.type)) { + // no-op + m.type = type; + return m; + } + Type *mt = base_type(m.type); + GB_ASSERT(mt->kind == Type_Matrix); + + if (lb_is_matrix_simdable(mt)) { + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + unsigned row_count = cast(unsigned)mt->Matrix.row_count; + unsigned column_count = cast(unsigned)mt->Matrix.column_count; + + auto rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count); + auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count); + + LLVMValueRef vector = lb_matrix_to_vector(p, m); + for (unsigned i = 0; i < row_count; i++) { + for (unsigned j = 0; j < column_count; j++) { + unsigned offset = stride*j + i; + mask_elems[j] = lb_const_int(p->module, t_u32, offset).value; + } + + // transpose mask + LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count); + LLVMValueRef row = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, ""); + rows[i] = row; + } + + lbAddr res = lb_add_local_generated(p, type, true); + for_array(i, rows) { + LLVMValueRef row = rows[i]; + lbValue dst_row_ptr = lb_emit_matrix_epi(p, res.addr, 0, i); + LLVMValueRef ptr = dst_row_ptr.value; + ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(row), 0), ""); + LLVMBuildStore(p->builder, row, ptr); + } + + return lb_addr_load(p, res); + } + + lbAddr res = lb_add_local_generated(p, type, true); + + i64 row_count = mt->Matrix.row_count; + i64 column_count = mt->Matrix.column_count; + for (i64 j = 0; j < column_count; j++) { + for (i64 i = 0; i < row_count; i++) { + lbValue src = lb_emit_matrix_ev(p, m, i, j); + lbValue dst = lb_emit_matrix_epi(p, res.addr, j, i); + lb_emit_store(p, dst, src); + } + } + return lb_addr_load(p, res); +} + +lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) { + lbAddr res = lb_add_local_generated(p, type, true); + LLVMValueRef res_ptr = res.addr.value; + unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector))); + LLVMSetAlignment(res_ptr, alignment); + + res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), ""); + LLVMBuildStore(p->builder, vector, res_ptr); + + return lb_addr_load(p, res); +} + +lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) { + if (is_type_array(m.type)) { + // no-op + m.type = type; + return m; + } + Type *mt = base_type(m.type); + GB_ASSERT(mt->kind == Type_Matrix); + + if (lb_is_matrix_simdable(mt)) { + LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m); + return lb_matrix_cast_vector_to_type(p, vector, type); + } + + lbAddr res = lb_add_local_generated(p, type, true); + + i64 row_count = mt->Matrix.row_count; + i64 column_count = mt->Matrix.column_count; + for (i64 j = 0; j < column_count; j++) { + for (i64 i = 0; i < row_count; i++) { + lbValue src = lb_emit_matrix_ev(p, m, i, j); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + lb_emit_store(p, dst, src); + } + } + return lb_addr_load(p, res); +} + + +lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) { + Type *mt = base_type(type); + Type *at = base_type(a.type); + Type *bt = base_type(b.type); + GB_ASSERT(mt->kind == Type_Matrix); + GB_ASSERT(at->kind == Type_Array); + GB_ASSERT(bt->kind == Type_Array); + + + i64 row_count = mt->Matrix.row_count; + i64 column_count = mt->Matrix.column_count; + + GB_ASSERT(row_count == at->Array.count); + GB_ASSERT(column_count == bt->Array.count); + + + lbAddr res = lb_add_local_generated(p, type, true); + + for (i64 j = 0; j < column_count; j++) { + for (i64 i = 0; i < row_count; i++) { + lbValue x = lb_emit_struct_ev(p, a, cast(i32)i); + lbValue y = lb_emit_struct_ev(p, b, cast(i32)j); + lbValue src = lb_emit_arith(p, Token_Mul, x, y, mt->Matrix.elem); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + lb_emit_store(p, dst, src); + } + } + return lb_addr_load(p, res); + +} + +lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { + // TODO(bill): Handle edge case for f16 types on x86(-64) platforms + + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + GB_ASSERT(is_type_matrix(type)); + GB_ASSERT(is_type_matrix(xt)); + GB_ASSERT(is_type_matrix(yt)); + GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); + GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); + + Type *elem = xt->Matrix.elem; + + unsigned outer_rows = cast(unsigned)xt->Matrix.row_count; + unsigned inner = cast(unsigned)xt->Matrix.column_count; + unsigned outer_columns = cast(unsigned)yt->Matrix.column_count; + + if (lb_is_matrix_simdable(xt)) { + unsigned x_stride = cast(unsigned)matrix_type_stride_in_elems(xt); + unsigned y_stride = cast(unsigned)matrix_type_stride_in_elems(yt); + + auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows); + auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), outer_columns); + + LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs); + LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs); + + auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner); + for (unsigned i = 0; i < outer_rows; i++) { + for (unsigned j = 0; j < inner; j++) { + unsigned offset = x_stride*j + i; + mask_elems[j] = lb_const_int(p->module, t_u32, offset).value; + } + + // transpose mask + LLVMValueRef mask = LLVMConstVector(mask_elems.data, inner); + LLVMValueRef row = LLVMBuildShuffleVector(p->builder, x_vector, LLVMGetUndef(LLVMTypeOf(x_vector)), mask, ""); + x_rows[i] = row; + } + + for (unsigned i = 0; i < outer_columns; i++) { + LLVMValueRef mask = llvm_mask_iota(p->module, y_stride*i, inner); + LLVMValueRef column = LLVMBuildShuffleVector(p->builder, y_vector, LLVMGetUndef(LLVMTypeOf(y_vector)), mask, ""); + y_columns[i] = column; + } + + lbAddr res = lb_add_local_generated(p, type, true); + for_array(i, x_rows) { + LLVMValueRef x_row = x_rows[i]; + for_array(j, y_columns) { + LLVMValueRef y_column = y_columns[j]; + LLVMValueRef elem = llvm_vector_dot(p, x_row, y_column); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + LLVMBuildStore(p->builder, elem, dst.value); + } + } + return lb_addr_load(p, res); + } + + { + lbAddr res = lb_add_local_generated(p, type, true); + + auto inners = slice_make<lbValue[2]>(permanent_allocator(), inner); + + for (unsigned j = 0; j < outer_columns; j++) { + for (unsigned i = 0; i < outer_rows; i++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + for (unsigned k = 0; k < inner; k++) { + inners[k][0] = lb_emit_matrix_ev(p, lhs, i, k); + inners[k][1] = lb_emit_matrix_ev(p, rhs, k, j); + } + + lbValue sum = lb_const_nil(p->module, elem); + for (unsigned k = 0; k < inner; k++) { + lbValue a = inners[k][0]; + lbValue b = inners[k][1]; + sum = lb_emit_mul_add(p, a, b, sum, elem); + } + lb_emit_store(p, dst, sum); + } + } + + return lb_addr_load(p, res); + } +} + +lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { + // TODO(bill): Handle edge case for f16 types on x86(-64) platforms + + Type *mt = base_type(lhs.type); + Type *vt = base_type(rhs.type); + + GB_ASSERT(is_type_matrix(mt)); + GB_ASSERT(is_type_array_like(vt)); + + i64 vector_count = get_array_type_count(vt); + + GB_ASSERT(mt->Matrix.column_count == vector_count); + GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt))); + + Type *elem = mt->Matrix.elem; + + if (lb_is_matrix_simdable(mt)) { + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + + unsigned row_count = cast(unsigned)mt->Matrix.row_count; + unsigned column_count = cast(unsigned)mt->Matrix.column_count; + auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count); + auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count); + + LLVMValueRef matrix_vector = lb_matrix_to_vector(p, lhs); + + for (unsigned column_index = 0; column_index < column_count; column_index++) { + LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count); + LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, ""); + m_columns[column_index] = column; + } + + for (unsigned row_index = 0; row_index < column_count; row_index++) { + LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value; + LLVMValueRef row = llvm_vector_broadcast(p, value, row_count); + v_rows[row_index] = row; + } + + GB_ASSERT(column_count > 0); + + LLVMValueRef vector = nullptr; + for (i64 i = 0; i < column_count; i++) { + if (i == 0) { + vector = llvm_vector_mul(p, m_columns[i], v_rows[i]); + } else { + vector = llvm_vector_mul_add(p, m_columns[i], v_rows[i], vector); + } + } + + return lb_matrix_cast_vector_to_type(p, vector, type); + } + + lbAddr res = lb_add_local_generated(p, type, true); + + for (i64 i = 0; i < mt->Matrix.row_count; i++) { + for (i64 j = 0; j < mt->Matrix.column_count; j++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, 0); + lbValue d0 = lb_emit_load(p, dst); + + lbValue a = lb_emit_matrix_ev(p, lhs, i, j); + lbValue b = lb_emit_struct_ev(p, rhs, cast(i32)j); + lbValue c = lb_emit_mul_add(p, a, b, d0, elem); + lb_emit_store(p, dst, c); + } + } + + return lb_addr_load(p, res); +} + +lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { + // TODO(bill): Handle edge case for f16 types on x86(-64) platforms + + Type *mt = base_type(rhs.type); + Type *vt = base_type(lhs.type); + + GB_ASSERT(is_type_matrix(mt)); + GB_ASSERT(is_type_array_like(vt)); + + i64 vector_count = get_array_type_count(vt); + + GB_ASSERT(vector_count == mt->Matrix.row_count); + GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt))); + + Type *elem = mt->Matrix.elem; + + if (lb_is_matrix_simdable(mt)) { + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + + unsigned row_count = cast(unsigned)mt->Matrix.row_count; + unsigned column_count = cast(unsigned)mt->Matrix.column_count; gb_unused(column_count); + auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), row_count); + auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count); + + LLVMValueRef matrix_vector = lb_matrix_to_vector(p, rhs); + + auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count); + for (unsigned row_index = 0; row_index < row_count; row_index++) { + for (unsigned column_index = 0; column_index < column_count; column_index++) { + unsigned offset = row_index + column_index*stride; + mask_elems[column_index] = lb_const_int(p->module, t_u32, offset).value; + } + + // transpose mask + LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count); + LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, ""); + m_columns[row_index] = column; + } + + for (unsigned column_index = 0; column_index < row_count; column_index++) { + LLVMValueRef value = lb_emit_struct_ev(p, lhs, column_index).value; + LLVMValueRef row = llvm_vector_broadcast(p, value, column_count); + v_rows[column_index] = row; + } + + GB_ASSERT(row_count > 0); + + LLVMValueRef vector = nullptr; + for (i64 i = 0; i < row_count; i++) { + if (i == 0) { + vector = llvm_vector_mul(p, v_rows[i], m_columns[i]); + } else { + vector = llvm_vector_mul_add(p, v_rows[i], m_columns[i], vector); + } + } + + lbAddr res = lb_add_local_generated(p, type, true); + LLVMValueRef res_ptr = res.addr.value; + unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector))); + LLVMSetAlignment(res_ptr, alignment); + + res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), ""); + LLVMBuildStore(p->builder, vector, res_ptr); + + return lb_addr_load(p, res); + } + + lbAddr res = lb_add_local_generated(p, type, true); + + for (i64 j = 0; j < mt->Matrix.column_count; j++) { + for (i64 k = 0; k < mt->Matrix.row_count; k++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, 0, j); + lbValue d0 = lb_emit_load(p, dst); + + lbValue a = lb_emit_struct_ev(p, lhs, cast(i32)k); + lbValue b = lb_emit_matrix_ev(p, rhs, k, j); + lbValue c = lb_emit_mul_add(p, a, b, d0, elem); + lb_emit_store(p, dst, c); + } + } + + return lb_addr_load(p, res); +} + + + + +lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false) { + GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); + + + if (op == Token_Mul && !component_wise) { + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + if (xt->kind == Type_Matrix) { + if (yt->kind == Type_Matrix) { + return lb_emit_matrix_mul(p, lhs, rhs, type); + } else if (is_type_array_like(yt)) { + return lb_emit_matrix_mul_vector(p, lhs, rhs, type); + } + } else if (is_type_array_like(xt)) { + GB_ASSERT(yt->kind == Type_Matrix); + return lb_emit_vector_mul_matrix(p, lhs, rhs, type); + } + + } else { + if (is_type_matrix(lhs.type)) { + rhs = lb_emit_conv(p, rhs, lhs.type); + } else { + lhs = lb_emit_conv(p, lhs, rhs.type); + } + + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + GB_ASSERT_MSG(are_types_identical(xt, yt), "%s %.*s %s", type_to_string(lhs.type), LIT(token_strings[op]), type_to_string(rhs.type)); + GB_ASSERT(xt->kind == Type_Matrix); + // element-wise arithmetic + // pretend it is an array + lbValue array_lhs = lhs; + lbValue array_rhs = rhs; + Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_internal_elems(xt)); + GB_ASSERT(type_size_of(array_type) == type_size_of(xt)); + + array_lhs.type = array_type; + array_rhs.type = array_type; + + if (token_is_comparison(op)) { + lbValue res = lb_emit_comp(p, op, array_lhs, array_rhs); + return lb_emit_conv(p, res, type); + } else { + lbValue array = lb_emit_arith(p, op, array_lhs, array_rhs, array_type); + array.type = type; + return array; + } + + } + + GB_PANIC("TODO: lb_emit_arith_matrix"); + + return {}; +} + lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) { return lb_emit_arith_array(p, op, lhs, rhs, type); + } else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) { + return lb_emit_arith_matrix(p, op, lhs, rhs, type); } else if (is_type_complex(type)) { lhs = lb_emit_conv(p, lhs, type); rhs = lb_emit_conv(p, rhs, type); @@ -749,6 +1283,13 @@ lbValue lb_build_binary_expr(lbProcedure *p, Ast *expr) { ast_node(be, BinaryExpr, expr); TypeAndValue tv = type_and_value_of_expr(expr); + + if (is_type_matrix(be->left->tav.type) || is_type_matrix(be->right->tav.type)) { + lbValue left = lb_build_expr(p, be->left); + lbValue right = lb_build_expr(p, be->right); + return lb_emit_arith_matrix(p, be->op.kind, left, right, default_type(tv.type)); + } + switch (be->op.kind) { case Token_Add: @@ -1417,6 +1958,62 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { } return lb_addr_load(p, v); } + + if (is_type_matrix(dst) && !is_type_matrix(src)) { + GB_ASSERT_MSG(dst->Matrix.row_count == dst->Matrix.column_count, "%s <- %s", type_to_string(dst), type_to_string(src)); + + Type *elem = base_array_type(dst); + lbValue e = lb_emit_conv(p, value, elem); + lbAddr v = lb_add_local_generated(p, t, false); + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + isize j = cast(isize)i; + lbValue ptr = lb_emit_matrix_epi(p, v.addr, j, j); + lb_emit_store(p, ptr, e); + } + + + return lb_addr_load(p, v); + } + + if (is_type_matrix(dst) && is_type_matrix(src)) { + GB_ASSERT(dst->kind == Type_Matrix); + GB_ASSERT(src->kind == Type_Matrix); + lbAddr v = lb_add_local_generated(p, t, true); + + if (is_matrix_square(dst) && is_matrix_square(dst)) { + for (i64 j = 0; j < dst->Matrix.column_count; j++) { + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + if (i < src->Matrix.row_count && j < src->Matrix.column_count) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_emit_matrix_ev(p, value, i, j); + lb_emit_store(p, d, s); + } else if (i == j) { + lbValue d = lb_emit_matrix_epi(p, v.addr, i, j); + lbValue s = lb_const_value(p->module, dst->Matrix.elem, exact_value_i64(1), true); + lb_emit_store(p, d, s); + } + } + } + } else { + i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count; + i64 src_count = src->Matrix.row_count*src->Matrix.column_count; + GB_ASSERT(dst_count == src_count); + + for (i64 j = 0; j < src->Matrix.column_count; j++) { + for (i64 i = 0; i < src->Matrix.row_count; i++) { + lbValue s = lb_emit_matrix_ev(p, value, i, j); + i64 index = i + j*src->Matrix.row_count; + i64 dst_i = index%dst->Matrix.row_count; + i64 dst_j = index/dst->Matrix.row_count; + lbValue d = lb_emit_matrix_epi(p, v.addr, dst_i, dst_j); + lb_emit_store(p, d, s); + } + } + } + return lb_addr_load(p, v); + } + + if (is_type_any(dst)) { if (is_type_untyped_nil(src)) { @@ -2481,6 +3078,10 @@ lbValue lb_build_expr(lbProcedure *p, Ast *expr) { case_ast_node(ie, IndexExpr, expr); return lb_addr_load(p, lb_build_addr(p, expr)); case_end; + + case_ast_node(ie, MatrixIndexExpr, expr); + return lb_addr_load(p, lb_build_addr(p, expr)); + case_end; case_ast_node(ia, InlineAsmExpr, expr); Type *t = type_of_expr(expr); @@ -2976,6 +3577,25 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { lbValue v = lb_emit_ptr_offset(p, elem, index); return lb_addr(v); } + + case Type_Matrix: { + lbValue matrix = {}; + matrix = lb_build_addr_ptr(p, ie->expr); + if (deref) { + matrix = lb_emit_load(p, matrix); + } + lbValue index = lb_build_expr(p, ie->index); + index = lb_emit_conv(p, index, t_int); + lbValue elem = lb_emit_matrix_ep(p, matrix, lb_const_int(p->module, t_int, 0), index); + elem = lb_emit_conv(p, elem, alloc_type_pointer(type_of_expr(expr))); + + auto index_tv = type_and_value_of_expr(ie->index); + if (index_tv.mode != Addressing_Constant) { + lbValue len = lb_const_int(p->module, t_int, t->Matrix.column_count); + lb_emit_bounds_check(p, ast_token(ie->index), index, len); + } + return lb_addr(elem); + } case Type_Basic: { // Basic_string @@ -2998,6 +3618,35 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { } } case_end; + + case_ast_node(ie, MatrixIndexExpr, expr); + Type *t = base_type(type_of_expr(ie->expr)); + + bool deref = is_type_pointer(t); + t = base_type(type_deref(t)); + + lbValue m = {}; + m = lb_build_addr_ptr(p, ie->expr); + if (deref) { + m = lb_emit_load(p, m); + } + lbValue row_index = lb_build_expr(p, ie->row_index); + lbValue column_index = lb_build_expr(p, ie->column_index); + row_index = lb_emit_conv(p, row_index, t_int); + column_index = lb_emit_conv(p, column_index, t_int); + lbValue elem = lb_emit_matrix_ep(p, m, row_index, column_index); + + auto row_index_tv = type_and_value_of_expr(ie->row_index); + auto column_index_tv = type_and_value_of_expr(ie->column_index); + if (row_index_tv.mode != Addressing_Constant || column_index_tv.mode != Addressing_Constant) { + lbValue row_count = lb_const_int(p->module, t_int, t->Matrix.row_count); + lbValue column_count = lb_const_int(p->module, t_int, t->Matrix.column_count); + lb_emit_matrix_bounds_check(p, ast_token(ie->row_index), row_index, column_index, row_count, column_count); + } + return lb_addr(elem); + + + case_end; case_ast_node(se, SliceExpr, expr); @@ -3246,6 +3895,7 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { case Type_Slice: et = bt->Slice.elem; break; case Type_BitSet: et = bt->BitSet.elem; break; case Type_SimdVector: et = bt->SimdVector.elem; break; + case Type_Matrix: et = bt->Matrix.elem; break; } String proc_name = {}; @@ -3777,7 +4427,104 @@ lbAddr lb_build_addr(lbProcedure *p, Ast *expr) { } break; } + + case Type_Matrix: { + if (cl->elems.count > 0) { + lb_addr_store(p, v, lb_const_value(p->module, type, exact_value_compound(expr))); + auto temp_data = array_make<lbCompoundLitElemTempData>(temporary_allocator(), 0, cl->elems.count); + + // NOTE(bill): Separate value, gep, store into their own chunks + for_array(i, cl->elems) { + Ast *elem = cl->elems[i]; + + if (elem->kind == Ast_FieldValue) { + ast_node(fv, FieldValue, elem); + if (lb_is_elem_const(fv->value, et)) { + continue; + } + if (is_ast_range(fv->field)) { + ast_node(ie, BinaryExpr, fv->field); + TypeAndValue lo_tav = ie->left->tav; + TypeAndValue hi_tav = ie->right->tav; + GB_ASSERT(lo_tav.mode == Addressing_Constant); + GB_ASSERT(hi_tav.mode == Addressing_Constant); + + TokenKind op = ie->op.kind; + i64 lo = exact_value_to_i64(lo_tav.value); + i64 hi = exact_value_to_i64(hi_tav.value); + if (op != Token_RangeHalf) { + hi += 1; + } + + lbValue value = lb_build_expr(p, fv->value); + + for (i64 k = lo; k < hi; k++) { + lbCompoundLitElemTempData data = {}; + data.value = value; + + data.elem_index = cast(i32)matrix_index_to_offset(bt, k); + array_add(&temp_data, data); + } + + } else { + auto tav = fv->field->tav; + GB_ASSERT(tav.mode == Addressing_Constant); + i64 index = exact_value_to_i64(tav.value); + + lbValue value = lb_build_expr(p, fv->value); + lbCompoundLitElemTempData data = {}; + data.value = lb_emit_conv(p, value, et); + data.expr = fv->value; + + data.elem_index = cast(i32)matrix_index_to_offset(bt, index); + array_add(&temp_data, data); + } + + } else { + if (lb_is_elem_const(elem, et)) { + continue; + } + lbCompoundLitElemTempData data = {}; + data.expr = elem; + data.elem_index = cast(i32)matrix_index_to_offset(bt, i); + array_add(&temp_data, data); + } + } + + for_array(i, temp_data) { + temp_data[i].gep = lb_emit_array_epi(p, lb_addr_get_ptr(p, v), temp_data[i].elem_index); + } + + for_array(i, temp_data) { + lbValue field_expr = temp_data[i].value; + Ast *expr = temp_data[i].expr; + + auto prev_hint = lb_set_copy_elision_hint(p, lb_addr(temp_data[i].gep), expr); + + if (field_expr.value == nullptr) { + field_expr = lb_build_expr(p, expr); + } + Type *t = field_expr.type; + GB_ASSERT(t->kind != Type_Tuple); + lbValue ev = lb_emit_conv(p, field_expr, et); + + if (!p->copy_elision_hint.used) { + temp_data[i].value = ev; + } + + lb_reset_copy_elision_hint(p, prev_hint); + } + + for_array(i, temp_data) { + if (temp_data[i].value.value != nullptr) { + lb_emit_store(p, temp_data[i].gep, temp_data[i].value); + } + } + } + break; + } + } return v; |