diff options
| author | gingerBill <bill@gingerbill.org> | 2021-10-20 02:06:56 +0100 |
|---|---|---|
| committer | gingerBill <bill@gingerbill.org> | 2021-10-20 02:06:56 +0100 |
| commit | 68afbb37f40b10fd01dda9e5640cc7ae2535a371 (patch) | |
| tree | 57aacd7b2e073e077d68ef3143e15055b72b0198 /src | |
| parent | 7faca7066c30d6e663b268dc1e8ec66710ae3dd5 (diff) | |
Add builtin `outer_product`
Diffstat (limited to 'src')
| -rw-r--r-- | src/check_builtin.cpp | 60 | ||||
| -rw-r--r-- | src/checker_builtin_procs.hpp | 2 | ||||
| -rw-r--r-- | src/llvm_backend_expr.cpp | 32 | ||||
| -rw-r--r-- | src/llvm_backend_proc.cpp | 8 |
4 files changed, 102 insertions, 0 deletions
diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp index 659a74ad7..1d033932f 100644 --- a/src/check_builtin.cpp +++ b/src/check_builtin.cpp @@ -2017,6 +2017,66 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32 operand->type = check_matrix_type_hint(operand->type, type_hint); break; } + + case BuiltinProc_outer_product: { + Operand x = {}; + Operand y = {}; + check_expr(c, &x, ce->args[0]); + if (x.mode == Addressing_Invalid) { + return false; + } + check_expr(c, &y, ce->args[1]); + if (y.mode == Addressing_Invalid) { + return false; + } + if (!is_operand_value(x) || !is_operand_value(y)) { + error(call, "'%.*s' expects only arrays", LIT(builtin_name)); + return false; + } + + if (!is_type_array(x.type) && !is_type_array(y.type)) { + gbString s1 = type_to_string(x.type); + gbString s2 = type_to_string(y.type); + error(call, "'%.*s' expects only arrays, got %s and %s", LIT(builtin_name), s1, s2); + gb_string_free(s2); + gb_string_free(s1); + return false; + } + + Type *xt = base_type(x.type); + Type *yt = base_type(y.type); + GB_ASSERT(xt->kind == Type_Array); + GB_ASSERT(yt->kind == Type_Array); + if (!are_types_identical(xt->Array.elem, yt->Array.elem)) { + gbString s1 = type_to_string(xt->Array.elem); + gbString s2 = type_to_string(yt->Array.elem); + error(call, "'%.*s' mismatched element types, got %s vs %s", LIT(builtin_name), s1, s2); + gb_string_free(s2); + gb_string_free(s1); + return false; + } + + if (xt->Array.count == 0 || yt->Array.count == 0) { + gbString s1 = type_to_string(x.type); + gbString s2 = type_to_string(y.type); + error(call, "'%.*s' expects only arrays of non-zero length, got %s and %s", LIT(builtin_name), s1, s2); + gb_string_free(s2); + gb_string_free(s1); + return false; + } + + i64 max_count = xt->Array.count*yt->Array.count; + if (max_count > MAX_MATRIX_ELEMENT_COUNT) { + error(call, "Product of the array lengths exceed the maximum matrix element count, got %d, expected a maximum of %d", cast(int)max_count, MAX_MATRIX_ELEMENT_COUNT); + return false; + } + + operand->mode = Addressing_Value; + operand->type = alloc_type_matrix(xt->Array.elem, xt->Array.count, yt->Array.count); + operand->type = check_matrix_type_hint(operand->type, type_hint); + break; + } + case BuiltinProc_simd_vector: { Operand x = {}; diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp index 21a33bdd3..2c7392b09 100644 --- a/src/checker_builtin_procs.hpp +++ b/src/checker_builtin_procs.hpp @@ -36,6 +36,7 @@ enum BuiltinProcId { BuiltinProc_soa_unzip, BuiltinProc_transpose, + BuiltinProc_outer_product, BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures @@ -278,6 +279,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = { {STR_LIT("soa_unzip"), 1, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT("transpose"), 1, false, Expr_Expr, BuiltinProcPkg_builtin}, + {STR_LIT("outer_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin}, {STR_LIT(""), 0, true, Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index d41a0a127..27f12a829 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -522,9 +522,41 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { } } return lb_addr_load(p, res); +} + + +lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) { + Type *mt = base_type(type); + Type *at = base_type(a.type); + Type *bt = base_type(b.type); + GB_ASSERT(mt->kind == Type_Matrix); + GB_ASSERT(at->kind == Type_Array); + GB_ASSERT(bt->kind == Type_Array); + + + i64 row_count = mt->Matrix.row_count; + i64 column_count = mt->Matrix.column_count; + + GB_ASSERT(row_count == at->Array.count); + GB_ASSERT(column_count == bt->Array.count); + + + lbAddr res = lb_add_local_generated(p, type, true); + + for (i64 j = 0; j < column_count; j++) { + for (i64 i = 0; i < row_count; i++) { + lbValue x = lb_emit_struct_ev(p, a, cast(i32)i); + lbValue y = lb_emit_struct_ev(p, b, cast(i32)j); + lbValue src = lb_emit_arith(p, Token_Mul, x, y, mt->Matrix.elem); + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + lb_emit_store(p, dst, src); + } + } + return lb_addr_load(p, res); } + lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp index 1431fffaa..5a7fc1626 100644 --- a/src/llvm_backend_proc.cpp +++ b/src/llvm_backend_proc.cpp @@ -1263,6 +1263,14 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv, lbValue m = lb_build_expr(p, ce->args[0]); return lb_emit_matrix_tranpose(p, m, tv.type); } + + case BuiltinProc_outer_product: + { + lbValue a = lb_build_expr(p, ce->args[0]); + lbValue b = lb_build_expr(p, ce->args[1]); + return lb_emit_outer_product(p, a, b, tv.type); + } + // "Intrinsics" |