aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/check_builtin.cpp30
-rw-r--r--src/checker_builtin_procs.hpp2
-rw-r--r--src/llvm_backend_expr.cpp77
-rw-r--r--src/llvm_backend_proc.cpp6
4 files changed, 106 insertions, 9 deletions
diff --git a/src/check_builtin.cpp b/src/check_builtin.cpp
index a9427d4e0..b60509c03 100644
--- a/src/check_builtin.cpp
+++ b/src/check_builtin.cpp
@@ -2131,6 +2131,36 @@ bool check_builtin_procedure(CheckerContext *c, Operand *operand, Ast *call, i32
break;
}
+ case BuiltinProc_matrix_flatten: {
+ Operand x = {};
+ check_expr(c, &x, ce->args[0]);
+ if (x.mode == Addressing_Invalid) {
+ return false;
+ }
+ if (!is_operand_value(x)) {
+ error(call, "'%.*s' expects a matrix or array", LIT(builtin_name));
+ return false;
+ }
+ Type *t = base_type(x.type);
+ if (!is_type_matrix(t) && !is_type_array(t)) {
+ gbString s = type_to_string(x.type);
+ error(call, "'%.*s' expects a matrix or array, got %s", LIT(builtin_name), s);
+ gb_string_free(s);
+ return false;
+ }
+
+ operand->mode = Addressing_Value;
+ if (is_type_array(t)) {
+ // Do nothing
+ operand->type = x.type;
+ } else {
+ GB_ASSERT(t->kind == Type_Matrix);
+ operand->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count*t->Matrix.column_count);
+ }
+ operand->type = check_matrix_type_hint(operand->type, type_hint);
+ break;
+ }
+
case BuiltinProc_simd_vector: {
Operand x = {};
diff --git a/src/checker_builtin_procs.hpp b/src/checker_builtin_procs.hpp
index de4e99d14..5594c1a1a 100644
--- a/src/checker_builtin_procs.hpp
+++ b/src/checker_builtin_procs.hpp
@@ -38,6 +38,7 @@ enum BuiltinProcId {
BuiltinProc_transpose,
BuiltinProc_outer_product,
BuiltinProc_hadamard_product,
+ BuiltinProc_matrix_flatten,
BuiltinProc_DIRECTIVE, // NOTE(bill): This is used for specialized hash-prefixed procedures
@@ -282,6 +283,7 @@ gb_global BuiltinProc builtin_procs[BuiltinProc_COUNT] = {
{STR_LIT("transpose"), 1, false, Expr_Expr, BuiltinProcPkg_builtin},
{STR_LIT("outer_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin},
{STR_LIT("hadamard_product"), 2, false, Expr_Expr, BuiltinProcPkg_builtin},
+ {STR_LIT("matrix_flatten"), 1, false, Expr_Expr, BuiltinProcPkg_builtin},
{STR_LIT(""), 0, true, Expr_Expr, BuiltinProcPkg_builtin}, // DIRECTIVE
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index c1bdceba6..7d1c8e3db 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -517,6 +517,33 @@ LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) {
return matrix_vector;
}
+LLVMValueRef lb_matrix_to_trimmed_vector(lbProcedure *p, lbValue m) {
+ Type *mt = base_type(m.type);
+ GB_ASSERT(mt->kind == Type_Matrix);
+
+ unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt);
+ unsigned row_count = cast(unsigned)mt->Matrix.row_count;
+ unsigned column_count = cast(unsigned)mt->Matrix.column_count;
+
+ auto columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count);
+
+ LLVMValueRef vector = lb_matrix_to_vector(p, m);
+
+ unsigned mask_elems_index = 0;
+ auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), row_count*column_count);
+ for (unsigned j = 0; j < column_count; j++) {
+ for (unsigned i = 0; i < row_count; i++) {
+ unsigned offset = stride*j + i;
+ mask_elems[mask_elems_index++] = lb_const_int(p->module, t_u32, offset).value;
+ }
+ }
+
+ LLVMValueRef mask = LLVMConstVector(mask_elems.data, cast(unsigned)mask_elems.count);
+ LLVMValueRef trimmed_vector = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, "");
+ return trimmed_vector;
+}
+
+
lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
if (is_type_array(m.type)) {
// no-op
@@ -573,6 +600,46 @@ lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
return lb_addr_load(p, res);
}
+lbValue lb_matrix_cast_vector_to_type(lbProcedure *p, LLVMValueRef vector, Type *type) {
+ lbAddr res = lb_add_local_generated(p, type, true);
+ LLVMValueRef res_ptr = res.addr.value;
+ unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
+ LLVMSetAlignment(res_ptr, alignment);
+
+ res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
+ LLVMBuildStore(p->builder, vector, res_ptr);
+
+ return lb_addr_load(p, res);
+}
+
+lbValue lb_emit_matrix_flatten(lbProcedure *p, lbValue m, Type *type) {
+ if (is_type_array(m.type)) {
+ // no-op
+ m.type = type;
+ return m;
+ }
+ Type *mt = base_type(m.type);
+ GB_ASSERT(mt->kind == Type_Matrix);
+
+ if (lb_matrix_elem_simple(mt)) {
+ LLVMValueRef vector = lb_matrix_to_trimmed_vector(p, m);
+ return lb_matrix_cast_vector_to_type(p, vector, type);
+ }
+
+ lbAddr res = lb_add_local_generated(p, type, true);
+
+ i64 row_count = mt->Matrix.row_count;
+ i64 column_count = mt->Matrix.column_count;
+ for (i64 j = 0; j < column_count; j++) {
+ for (i64 i = 0; i < row_count; i++) {
+ lbValue src = lb_emit_matrix_ev(p, m, i, j);
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j);
+ lb_emit_store(p, dst, src);
+ }
+ }
+ return lb_addr_load(p, res);
+}
+
lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) {
Type *mt = base_type(type);
@@ -737,16 +804,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
vector = llvm_vector_add(p, vector, product);
}
}
-
- lbAddr res = lb_add_local_generated(p, type, true);
- LLVMValueRef res_ptr = res.addr.value;
- unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector)));
- LLVMSetAlignment(res_ptr, alignment);
-
- res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), "");
- LLVMBuildStore(p->builder, vector, res_ptr);
- return lb_addr_load(p, res);
+ return lb_matrix_cast_vector_to_type(p, vector, type);
}
lbAddr res = lb_add_local_generated(p, type, true);
diff --git a/src/llvm_backend_proc.cpp b/src/llvm_backend_proc.cpp
index da4e4ad28..8686b3262 100644
--- a/src/llvm_backend_proc.cpp
+++ b/src/llvm_backend_proc.cpp
@@ -1280,6 +1280,12 @@ lbValue lb_build_builtin_proc(lbProcedure *p, Ast *expr, TypeAndValue const &tv,
GB_ASSERT(is_type_matrix(tv.type));
return lb_emit_arith_matrix(p, Token_Mul, a, b, tv.type, true);
}
+
+ case BuiltinProc_matrix_flatten:
+ {
+ lbValue m = lb_build_expr(p, ce->args[0]);
+ return lb_emit_matrix_flatten(p, m, tv.type);
+ }
// "Intrinsics"