aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorgingerBill <bill@gingerbill.org>2021-10-20 02:18:30 +0100
committergingerBill <bill@gingerbill.org>2021-10-20 02:18:30 +0100
commitcee45c1b155fcc917c2b0f9cfdbfa060304255e1 (patch)
treeadb7cf8841ca8523e14e3c276580b9905777059f /src
parent68afbb37f40b10fd01dda9e5640cc7ae2535a371 (diff)
Add `hadamard_product`
Diffstat (limited to 'src')
-rw-r--r--src/check_builtin.cpp56
-rw-r--r--src/check_type.cpp12
-rw-r--r--src/checker_builtin_procs.hpp2
-rw-r--r--src/llvm_backend_expr.cpp6
-rw-r--r--src/llvm_backend_proc.cpp10
-rw-r--r--src/types.cpp11
6 files changed, 84 insertions, 13 deletions
diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp
index 1d033932f..a9427d4e0 100644
--- a/src/check_builtin.cpp
+++ b/src/check_builtin.cpp
@@ -2056,6 +2056,14 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
return false;
}
+ Type *elem = xt->Array.elem;
+
+ if (!is_type_valid_for_matrix_elems(elem)) {
+ gbString s = type_to_string(elem);
+ error(call, "Matrix elements types are limited to integers, floats, and complex, got %s", s);
+ gb_string_free(s);
+ }
+
if (xt->Array.count == 0 || yt->Array.count == 0) {
gbString s1 = type_to_string(x.type);
gbString s2 = type_to_string(y.type);
@@ -2072,7 +2080,53 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
}
operand->mode = Addressing_Value;
- operand->type = alloc_type_matrix(xt->Array.elem, xt->Array.count, yt->Array.count);
+ operand->type = alloc_type_matrix(elem, xt->Array.count, yt->Array.count);
+ operand->type = check_matrix_type_hint(operand->type, type_hint);
+ break;
+ }
+
+ case BuiltinProc_hadamard_product: {
+ Operand x = {};
+ Operand y = {};
+ check_expr(c, &x, ce->args[0]);
+ if (x.mode == Addressing_Invalid) {
+ return false;
+ }
+ check_expr(c, &y, ce->args[1]);
+ if (y.mode == Addressing_Invalid) {
+ return false;
+ }
+ if (!is_operand_value(x) || !is_operand_value(y)) {
+ error(call, "'%.*s' expects a matrix or array types", LIT(builtin_name));
+ return false;
+ }
+ if (!is_type_matrix(x.type) && !is_type_array(y.type)) {
+ gbString s1 = type_to_string(x.type);
+ gbString s2 = type_to_string(y.type);
+ error(call, "'%.*s' expects matrix or array values, got %s and %s", LIT(builtin_name), s1, s2);
+ gb_string_free(s2);
+ gb_string_free(s1);
+ return false;
+ }
+
+ if (!are_types_identical(x.type, y.type)) {
+ gbString s1 = type_to_string(x.type);
+ gbString s2 = type_to_string(y.type);
+ error(call, "'%.*s' values of the same type, got %s and %s", LIT(builtin_name), s1, s2);
+ gb_string_free(s2);
+ gb_string_free(s1);
+ return false;
+ }
+
+ Type *elem = core_array_type(x.type);
+ if (!is_type_valid_for_matrix_elems(elem)) {
+ gbString s = type_to_string(elem);
+ error(call, "'%.*s' expects elements to be types are limited to integers, floats, and complex, got %s", LIT(builtin_name), s);
+ gb_string_free(s);
+ }
+
+ operand->mode = Addressing_Value;
+ operand->type = x.type;
operand->type = check_matrix_type_hint(operand->type, type_hint);
break;
}
diff --git a/src/check_type.cpp b/src/check_type.cpp
index e752f192d..d9302c65a 100644
--- a/src/check_type.cpp
+++ b/src/check_type.cpp
@@ -997,8 +997,8 @@ void check_bit_set_type(CheckerContext *c, Type *type, Type *named_type, Ast *no
GB_ASSERT(lower <= upper);
- i64 bits = MAX_BITS;
- if (bs->underlying != nullptr) {
+ i64 bits = MAX_BITS
+; if (bs->underlying != nullptr) {
Type *u = check_type(c, bs->underlying);
if (!is_type_integer(u)) {
gbString ts = type_to_string(u);
@@ -2239,13 +2239,7 @@ void check_matrix_type(CheckerContext *ctx, Type **type, Ast *node) {
error(column.expr, "Matrix types are limited to a maximum of %d elements, got %lld", MAX_MATRIX_ELEMENT_COUNT, cast(long long)element_count);
}
- if (is_type_integer(elem)) {
- // okay
- } else if (is_type_float(elem)) {
- // okay
- } else if (is_type_complex(elem)) {
- // okay
- } else {
+ if (!is_type_valid_for_matrix_elems(elem)) {
gbString s = type_to_string(elem);
error(column.expr, "Matrix elements types are limited to integers, floats, and complex, got %s", s);
gb_string_free(s);
diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp
index 2c7392b09..de4e99d14 100644
--- a/src/checker_builtin_procs.hpp
+++ b/src/checker_builtin_procs.hpp
@@ -37,6 +37,7 @@ enum BuiltinProcId {
BuiltinProc_transpose,
BuiltinProc_outer_product,
+ BuiltinProc_hadamard_product,
BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures
@@ -280,6 +281,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
{STR_LIT("transpose"), 1, false, Expr_Expr, BuiltinProcPkg_builtin},
{STR_LIT("outer_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin},
+ {STR_LIT("hadamard_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin},
{STR_LIT(""), 0, true, Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index 27f12a829..b894bc7b8 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -672,13 +672,13 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type
-lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) {
+lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type, bool component_wise=false) {
GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type));
Type *xt = base_type(lhs.type);
Type *yt = base_type(rhs.type);
- if (op == Token_Mul) {
+ if (op == Token_Mul && !component_wise) {
if (xt->kind == Type_Matrix) {
if (yt->kind == Type_Matrix) {
return lb_emit_matrix_mul(p, lhs, rhs, type);
@@ -703,7 +703,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue
array_lhs.type = array_type;
array_rhs.type = array_type;
- lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, type);
+ lbValue array = lb_emit_arith_array(p, op, array_lhs, array_rhs, array_type);
array.type = type;
return array;
}
diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp
index 5a7fc1626..da4e4ad28 100644
--- a/src/llvm_backend_proc.cpp
+++ b/src/llvm_backend_proc.cpp
@@ -1270,6 +1270,16 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
lbValue b = lb_build_expr(p, ce->args[1]);
return lb_emit_outer_product(p, a, b, tv.type);
}
+ case BuiltinProc_hadamard_product:
+ {
+ lbValue a = lb_build_expr(p, ce->args[0]);
+ lbValue b = lb_build_expr(p, ce->args[1]);
+ if (is_type_array(tv.type)) {
+ return lb_emit_arith(p, Token_Mul, a, b, tv.type);
+ }
+ GB_ASSERT(is_type_matrix(tv.type));
+ return lb_emit_arith_matrix(p, Token_Mul, a, b, tv.type, true);
+ }
// "Intrinsics"
diff --git a/src/types.cpp b/src/types.cpp
index eaf1bac74..32e26bcc6 100644
--- a/src/types.cpp
+++ b/src/types.cpp
@@ -1333,6 +1333,17 @@ i64 matrix_indices_to_offset(Type *t, i64 row_index, i64 column_index) {
return stride_elems*column_index + row_index;
}
+bool is_type_valid_for_matrix_elems(Type *t) {
+ if (is_type_integer(t)) {
+ return true;
+ } else if (is_type_float(t)) {
+ return true;
+ } else if (is_type_complex(t)) {
+ return true;
+ }
+ return false;
+}
+
bool is_type_dynamic_array(Type *t) {
t = base_type(t);
return t->kind == Type_DynamicArray;