diff options
Diffstat (limited to 'src/llvm_backend_expr.cpp')
| -rw-r--r-- | src/llvm_backend_expr.cpp | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index d41a0a127..27f12a829 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -522,9 +522,41 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { } } return lb_addr_load(p, res); +} + + +lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) { + Type *mt = base_type(type); + Type *at = base_type(a.type); + Type *bt = base_type(b.type); + GB_ASSERT(mt->kind == Type_Matrix); + GB_ASSERT(at->kind == Type_Array); + GB_ASSERT(bt->kind == Type_Array); + + + i64 row_count = mt->Matrix.row_count; + i64 column_count = mt->Matrix.column_count; + + GB_ASSERT(row_count == at->Array.count); + GB_ASSERT(column_count == bt->Array.count); + + + lbAddr res = lb_add_local_generated(p, type, true); + + for (i64 j = 0; j < column_count; j++) { + for (i64 i = 0; i < row_count; i++) { + lbValue x = lb_emit_struct_ev(p, a, cast(i32)i); + lbValue y = lb_emit_struct_ev(p, b, cast(i32)j); + lbValue src = lb_emit_arith(p, Token_Mul, x, y, mt->Matrix.elem); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + lb_emit_store(p, dst, src); + } + } + return lb_addr_load(p, res); } + lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); |