aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorgingerBill <bill@gingerbill.org>2022-02-06 11:42:59 +0000
committergingerBill <bill@gingerbill.org>2022-02-06 11:42:59 +0000
commit19aec13a1060a521913abc6bd669080171d43594 (patch)
tree2995c7f26a5a1fc19a977845542281481fc12399 /src
parente896956275bf7177e93795ab3cbb7069e5d05ff2 (diff)
Support rank-2 arrays (matrix-like) for `transpose`
Diffstat (limited to 'src')
-rw-r--r--src/check_builtin.cpp38
-rw-r--r--src/llvm_backend_expr.cpp21
-rw-r--r--src/types.cpp19
3 files changed, 76 insertions, 2 deletions
diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp
index a42741976..d3a3103b1 100644
--- a/src/check_builtin.cpp
+++ b/src/check_builtin.cpp
@@ -2183,9 +2183,43 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
}
operand->mode = Addressing_Value;
- if (is_type_array(t)) {
+ if (t->kind == Type_Array) {
+ i32 rank = type_math_rank(t);
// Do nothing
- operand->type = x.type;
+ operand->type = x.type;
+ if (rank > 2) {
+ gbString s = type_to_string(x.type);
+ error(call, "'%.*s' expects a matrix or array with a rank of 2, got %s of rank %d", LIT(builtin_name), s, rank);
+ gb_string_free(s);
+ return false;
+ } else if (rank == 2) {
+ Type *inner = base_type(t->Array.elem);
+ GB_ASSERT(inner->kind == Type_Array);
+ Type *elem = inner->Array.elem;
+ Type *array_inner = alloc_type_array(elem, t->Array.count);
+ Type *array_outer = alloc_type_array(array_inner, inner->Array.count);
+ operand->type = array_outer;
+
+ i64 elements = t->Array.count*inner->Array.count;
+ i64 size = type_size_of(operand->type);
+ if (!is_type_valid_for_matrix_elems(elem)) {
+ gbString s = type_to_string(x.type);
+ error(call, "'%.*s' expects a matrix or array with a base element type of an integer, float, or complex number, got %s", LIT(builtin_name), s);
+ gb_string_free(s);
+ } else if (elements > MATRIX_ELEMENT_COUNT_MAX) {
+ gbString s = type_to_string(x.type);
+ error(call, "'%.*s' expects a matrix or array with a maximum of %d elements, got %s with %lld elements", LIT(builtin_name), MATRIX_ELEMENT_COUNT_MAX, s, elements);
+ gb_string_free(s);
+ } else if (elements > MATRIX_ELEMENT_COUNT_MAX) {
+ gbString s = type_to_string(x.type);
+ error(call, "'%.*s' expects a matrix or array with non-zero elements, got %s", LIT(builtin_name), MATRIX_ELEMENT_COUNT_MAX, s);
+ gb_string_free(s);
+ } else if (size > MATRIX_ELEMENT_MAX_SIZE) {
+ gbString s = type_to_string(x.type);
+ error(call, "Too large of a type for '%.*s', got %s of size %lld, maximum size %d", LIT(builtin_name), s, cast(long long)size, MATRIX_ELEMENT_MAX_SIZE);
+ gb_string_free(s);
+ }
+ }
} else {
GB_ASSERT(t->kind == Type_Matrix);
operand->type = alloc_type_matrix(t->Matrix.elem, t->Matrix.column_count, t->Matrix.row_count);
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index 715b7df78..29a86d116 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -580,6 +580,27 @@ LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
if (is_type_array(m.type)) {
+ i32 rank = type_math_rank(m.type);
+ if (rank == 2) {
+ lbAddr addr = lb_add_local_generated(p, type, false);
+ lbValue dst = addr.addr;
+ lbValue src = m;
+ i32 n = cast(i32)get_array_type_count(m.type);
+ i32 m = cast(i32)get_array_type_count(type);
+ // m.type == [n][m]T
+ // type == [m][n]T
+
+ for (i32 j = 0; j < m; j++) {
+ lbValue dst_col = lb_emit_struct_ep(p, dst, j);
+ for (i32 i = 0; i < n; i++) {
+ lbValue dst_row = lb_emit_struct_ep(p, dst_col, i);
+ lbValue src_col = lb_emit_struct_ev(p, src, i);
+ lbValue src_row = lb_emit_struct_ev(p, src_col, j);
+ lb_emit_store(p, dst_row, src_row);
+ }
+ }
+ return lb_addr_load(p, addr);
+ }
// no-op
m.type = type;
return m;
diff --git a/src/types.cpp b/src/types.cpp
index e0d35a12c..9ee6ba359 100644
--- a/src/types.cpp
+++ b/src/types.cpp
@@ -363,6 +363,7 @@ enum TypeInfoFlag : u32 {
enum : int {
MATRIX_ELEMENT_COUNT_MIN = 1,
MATRIX_ELEMENT_COUNT_MAX = 16,
+ MATRIX_ELEMENT_MAX_SIZE = MATRIX_ELEMENT_COUNT_MAX * (2 * 8), // complex128
};
@@ -1583,6 +1584,24 @@ Type *core_array_type(Type *t) {
}
}
+i32 type_math_rank(Type *t) {
+ i32 rank = 0;
+ for (;;) {
+ t = base_type(t);
+ switch (t->kind) {
+ case Type_Array:
+ rank += 1;
+ t = t->Array.elem;
+ break;
+ case Type_Matrix:
+ rank += 2;
+ t = t->Matrix.elem;
+ break;
+ default:
+ return rank;
+ }
+ }
+}
Type *base_complex_elem_type(Type *t) {