aboutsummaryrefslogtreecommitdiff
path: root/src/llvm_backend_expr.cpp
diff options
context:
space:
mode:
authorgingerBill <bill@gingerbill.org>2021-10-20 01:51:16 +0100
committergingerBill <bill@gingerbill.org>2021-10-20 01:51:16 +0100
commit7faca7066c30d6e663b268dc1e8ec66710ae3dd5 (patch)
tree0a1b66625ec88dfc4abbc2e7ae0a1e0336eafebf /src/llvm_backend_expr.cpp
parent3eaac057da11d28cbedd7321f9f6368588b0b4ee (diff)
Add builtin `transpose`
Diffstat (limited to 'src/llvm_backend_expr.cpp')
-rw-r--r--src/llvm_backend_expr.cpp133
1 files changed, 17 insertions, 116 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index 518ce33af..d41a0a127 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -502,116 +502,29 @@ bool lb_matrix_elem_simple(Type *t) {
return true;
}
-LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) {
- lbModule *m = p->module;
-
- Type *mt = base_type(lhs.type);
- GB_ASSERT(mt->kind == Type_Matrix);
- GB_ASSERT(lb_matrix_elem_simple(mt));
-
-
- i64 stride = matrix_type_stride_in_elems(mt);
- i64 rows = mt->Matrix.row_count;
- i64 columns = mt->Matrix.column_count;
- unsigned elem_count = cast(unsigned)(rows*columns);
-
- Type *elem = mt->Matrix.elem;
- LLVMTypeRef elem_type = lb_type(m, elem);
-
- LLVMTypeRef vector_type = LLVMVectorType(elem_type, elem_count);
- LLVMTypeRef types[] = {vector_type};
-
- char const *name = "llvm.matrix.column.major.load";
- unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
- GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
- LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-
- lbValue ptr = lb_address_from_load_or_generate_local(p, lhs);
- ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
-
- LLVMValueRef values[5] = {};
- values[0] = ptr.value;
- values[1] = lb_const_int(m, t_u64, stride).value;
- values[2] = LLVMConstNull(lb_type(m, t_llvm_bool));
- values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
- values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
-
- LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
- gb_printf_err("%s\n", LLVMPrintValueToString(call));
- // LLVMAddAttributeAtIndex(call, 0, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
- return call;
-}
-
-void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) {
- lbModule *m = p->module;
-
- Type *mt = base_type(lb_addr_type(addr));
+lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) {
+ if (is_type_array(m.type)) {
+ m.type = type;
+ return m;
+ }
+ Type *mt = base_type(m.type);
GB_ASSERT(mt->kind == Type_Matrix);
- GB_ASSERT(lb_matrix_elem_simple(mt));
-
- LLVMTypeRef vector_type = LLVMTypeOf(vector_value);
- LLVMTypeRef types[] = {vector_type};
-
- char const *name = "llvm.matrix.column.major.store";
- unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
- GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
- LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-
- lbValue ptr = lb_addr_get_ptr(p, addr);
- ptr = lb_emit_matrix_epi(p, ptr, 0, 0);
-
- unsigned vector_size = LLVMGetVectorSize(vector_type);
- GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size);
- i64 stride = matrix_type_stride_in_elems(mt);
-
- LLVMValueRef values[6] = {};
- values[0] = vector_value;
- values[1] = ptr.value;
- values[2] = lb_const_int(m, t_u64, stride).value;
- values[3] = LLVMConstNull(lb_type(m, t_llvm_bool));
- values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value;
- values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value;
+ lbAddr res = lb_add_local_generated(p, type, true);
- LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
- gb_printf_err("%s\n", LLVMPrintValueToString(call));
- // LLVMAddAttributeAtIndex(call, 1, lb_create_enum_attribute(p->module->ctx, "align", cast(u64)type_align_of(mt)));
- gb_unused(call);
-}
-
+ i64 row_count = mt->Matrix.row_count;
+ i64 column_count = mt->Matrix.column_count;
+ for (i64 j = 0; j < column_count; j++) {
+ for (i64 i = 0; i < row_count; i++) {
+ lbValue src = lb_emit_matrix_ev(p, m, i, j);
+ lbValue dst = lb_emit_matrix_epi(p, res.addr, j, i);
+ lb_emit_store(p, dst, src);
+ }
+ }
+ return lb_addr_load(p, res);
-LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) {
- lbModule *m = p->module;
-
- LLVMTypeRef a_type = LLVMTypeOf(a);
- LLVMTypeRef b_type = LLVMTypeOf(b);
-
- GB_ASSERT(LLVMGetElementType(a_type) == LLVMGetElementType(b_type));
-
- LLVMTypeRef elem_type = LLVMGetElementType(a_type);
-
- LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns));
-
- LLVMTypeRef types[] = {res_vector_type, a_type, b_type};
-
- char const *name = "llvm.matrix.multiply";
- unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
- GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
- LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
-
- LLVMValueRef values[5] = {};
- values[0] = a;
- values[1] = b;
- values[2] = lb_const_int(m, t_u32, outer_rows).value;
- values[3] = lb_const_int(m, t_u32, inner).value;
- values[4] = lb_const_int(m, t_u32, outer_columns).value;
-
- LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
- gb_printf_err("%s\n", LLVMPrintValueToString(call));
- return call;
}
-
lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
Type *xt = base_type(lhs.type);
Type *yt = base_type(rhs.type);
@@ -626,18 +539,6 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
goto slow_form;
}
- if (false) {
- // TODO(bill): LLVM ERROR: Do not know how to split the result of this operator!
- lbAddr res = lb_add_local_generated(p, type, true);
-
- LLVMValueRef a = llvm_matrix_column_major_load(p, lhs); gb_unused(a);
- LLVMValueRef b = llvm_matrix_column_major_load(p, rhs); gb_unused(b);
- LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count); gb_unused(c);
- llvm_matrix_column_major_store(p, res, c);
-
- return lb_addr_load(p, res);
- }
-
slow_form:
{
Type *elem = xt->Matrix.elem;