aboutsummaryrefslogtreecommitdiff
path: root/src/llvm_backend_expr.cpp
diff options
context:
space:
mode:
authorgingerBill <bill@gingerbill.org>2021-10-20 02:06:56 +0100
committergingerBill <bill@gingerbill.org>2021-10-20 02:06:56 +0100
commit68afbb37f40b10fd01dda9e5640cc7ae2535a371 (patch)
tree57aacd7b2e073e077d68ef3143e15055b72b0198 /src/llvm_backend_expr.cpp
parent7faca7066c30d6e663b268dc1e8ec66710ae3dd5 (diff)
Add builtin `outer_product`
Diffstat (limited to 'src/llvm_backend_expr.cpp')
-rw-r--r--src/llvm_backend_expr.cpp32
1 files changed, 32 insertions, 0 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index d41a0a127..27f12a829 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -522,9 +522,41 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
}
}
return lb_addr_load(p, res);
+}
+
+
+lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
+ Type *mt = base_type(type);
+ Type *at = base_type(a.type);
+ Type *bt = base_type(b.type);
+ GB_ASSERT(mt->kind == Type_Matrix);
+ GB_ASSERT(at->kind == Type_Array);
+ GB_ASSERT(bt->kind == Type_Array);
+
+
+ i64 row_count = mt->Matrix.row_count;
+ i64 column_count = mt->Matrix.column_count;
+
+ GB_ASSERT(row_count == at->Array.count);
+ GB_ASSERT(column_count == bt->Array.count);
+
+
+ lbAddr res = lb_add_local_generated(p, type, true);
+
+ for (i64 j = 0; j < column_count; j++) {
+ for (i64 i = 0; i < row_count; i++) {
+ lbValue x = lb_emit_struct_ev(p, a, cast(i32)i);
+ lbValue y = lb_emit_struct_ev(p, b, cast(i32)j);
+ lbValue src = lb_emit_arith(p, Token_Mul, x, y, mt->Matrix.elem);
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+ lb_emit_store(p, dst, src);
+ }
+ }
+ return lb_addr_load(p, res);
}
+
lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
Type *xt = base_type(lhs.type);
Type *yt = base_type(rhs.type);