diff options
Diffstat (limited to 'src/llvm_backend_expr.cpp')
| -rw-r--r-- | src/llvm_backend_expr.cpp | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 3056952f6..6b7d90ec0 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -477,10 +477,72 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r } +lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { + GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); + + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + if (op == Token_Mul) { + if (xt->kind == Type_Matrix) { + if (yt->kind == Type_Matrix) { + GB_ASSERT(is_type_matrix(type)); + GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); + GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); + + Type *elem = xt->Matrix.elem; + + lbAddr res = lb_add_local_generated(p, type, true); + for (i64 i = 0; i < xt->Matrix.row_count; i++) { + for (i64 j = 0; j < yt->Matrix.column_count; j++) { + for (i64 k = 0; k < xt->Matrix.column_count; k++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + + lbValue a = lb_emit_matrix_ev(p, lhs, i, k); + lbValue b = lb_emit_matrix_ev(p, rhs, k, j); + lbValue c = lb_emit_arith(p, op, a, b, elem); + lbValue d = lb_emit_load(p, dst); + lbValue e = lb_emit_arith(p, Token_Add, d, c, elem); + lb_emit_store(p, dst, e); + + } + } + } + + return lb_addr_load(p, res); + } + } + + } else { + GB_ASSERT(are_types_identical(xt, yt)); + GB_ASSERT(xt->kind == Type_Matrix); + // element-wise arithmetic + // pretend it is an array + lbValue array_lhs = lhs; + lbValue array_rhs = rhs; + Type *array_type = alloc_type_array(xt->Matrix.elem, matrix_type_total_elems(xt)); + GB_ASSERT(type_size_of(array_type) == type_size_of(type)); + + array_lhs.type = array_type; + array_rhs.type = array_type; + + lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, type); + array.type = type; + return array; + } + + GB_PANIC("TODO: lb_emit_arith_matrix"); + + return {}; +} + + lbValue lb_emit_arith(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { if (is_type_array_like(lhs.type) || is_type_array_like(rhs.type)) { return lb_emit_arith_array(p, op, lhs, rhs, type); + } else if (is_type_matrix(lhs.type) || is_type_matrix(rhs.type)) { + return lb_emit_arith_matrix(p, op, lhs, rhs, type); } else if (is_type_complex(type)) { lhs = lb_emit_conv(p, lhs, type); rhs = lb_emit_conv(p, rhs, type); @@ -1417,6 +1479,22 @@ lbValue lb_emit_conv(lbProcedure *p, lbValue value, Type *t) { } return lb_addr_load(p, v); } + + if (is_type_matrix(dst) && !is_type_matrix(src)) { + GB_ASSERT(dst->Matrix.row_count == dst->Matrix.column_count); + + Type *elem = base_array_type(dst); + lbValue e = lb_emit_conv(p, value, elem); + lbAddr v = lb_add_local_generated(p, t, false); + for (i64 i = 0; i < dst->Matrix.row_count; i++) { + isize j = cast(isize)i; + lbValue ptr = lb_emit_matrix_epi(p, v.addr, j, j); + lb_emit_store(p, ptr, e); + } + + + return lb_addr_load(p, v); + } if (is_type_any(dst)) { if (is_type_untyped_nil(src)) { |