aboutsummaryrefslogtreecommitdiff
path: root/src/check_builtin.cpp
diff options
context:
space:
mode:
authorgingerBill <bill@gingerbill.org>2021-10-20 02:06:56 +0100
committergingerBill <bill@gingerbill.org>2021-10-20 02:06:56 +0100
commit68afbb37f40b10fd01dda9e5640cc7ae2535a371 (patch)
tree57aacd7b2e073e077d68ef3143e15055b72b0198 /src/check_builtin.cpp
parent7faca7066c30d6e663b268dc1e8ec66710ae3dd5 (diff)
Add builtin `outer_product`
Diffstat (limited to 'src/check_builtin.cpp')
-rw-r--r--src/check_builtin.cpp60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp
index 659a74ad7..1d033932f 100644
--- a/src/check_builtin.cpp
+++ b/src/check_builtin.cpp
@@ -2017,6 +2017,66 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
operand->type = check_matrix_type_hint(operand->type, type_hint);
break;
}
+
+ case BuiltinProc_outer_product: {
+ Operand x = {};
+ Operand y = {};
+ check_expr(c, &x, ce->args[0]);
+ if (x.mode == Addressing_Invalid) {
+ return false;
+ }
+ check_expr(c, &y, ce->args[1]);
+ if (y.mode == Addressing_Invalid) {
+ return false;
+ }
+ if (!is_operand_value(x) || !is_operand_value(y)) {
+ error(call, "'%.*s' expects only arrays", LIT(builtin_name));
+ return false;
+ }
+
+ if (!is_type_array(x.type) && !is_type_array(y.type)) {
+ gbString s1 = type_to_string(x.type);
+ gbString s2 = type_to_string(y.type);
+ error(call, "'%.*s' expects only arrays, got %s and %s", LIT(builtin_name), s1, s2);
+ gb_string_free(s2);
+ gb_string_free(s1);
+ return false;
+ }
+
+ Type *xt = base_type(x.type);
+ Type *yt = base_type(y.type);
+ GB_ASSERT(xt->kind == Type_Array);
+ GB_ASSERT(yt->kind == Type_Array);
+ if (!are_types_identical(xt->Array.elem, yt->Array.elem)) {
+ gbString s1 = type_to_string(xt->Array.elem);
+ gbString s2 = type_to_string(yt->Array.elem);
+ error(call, "'%.*s' mismatched element types, got %s vs %s", LIT(builtin_name), s1, s2);
+ gb_string_free(s2);
+ gb_string_free(s1);
+ return false;
+ }
+
+ if (xt->Array.count == 0 || yt->Array.count == 0) {
+ gbString s1 = type_to_string(x.type);
+ gbString s2 = type_to_string(y.type);
+ error(call, "'%.*s' expects only arrays of non-zero length, got %s and %s", LIT(builtin_name), s1, s2);
+ gb_string_free(s2);
+ gb_string_free(s1);
+ return false;
+ }
+
+ i64 max_count = xt->Array.count*yt->Array.count;
+ if (max_count > MAX_MATRIX_ELEMENT_COUNT) {
+ error(call, "Product of the array lengths exceed the maximum matrix element count, got %d, expected a maximum of %d", cast(int)max_count, MAX_MATRIX_ELEMENT_COUNT);
+ return false;
+ }
+
+ operand->mode = Addressing_Value;
+ operand->type = alloc_type_matrix(xt->Array.elem, xt->Array.count, yt->Array.count);
+ operand->type = check_matrix_type_hint(operand->type, type_hint);
+ break;
+ }
+
case BuiltinProc_simd_vector: {
Operand x = {};