diff options
| field | value | date |
|---|---|---|
| author | gingerBill <bill@gingerbill.org> | 2021-10-20 12:39:38 +0100 |
| committer | gingerBill <bill@gingerbill.org> | 2021-10-20 12:39:38 +0100 |
| commit | 0fd525d7789d2a0786b28677c5dd4cbd263f4537 (patch) | |
| tree | 094ffc5afbd43675652450d8c2f428b0ae82c849 /src | |
| parent | 07bf64ae5243d3e2f38da9cf9da81ef7a99a6f44 (diff) | |
Make `lb_emit_matrix_mul_vector` use SIMD if possible
Diffstat (limited to 'src')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | src/llvm_backend_expr.cpp | 68 |
| -rw-r--r-- | src/llvm_backend_utility.cpp | 32 |
2 files changed, 97 insertions, 3 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index b894bc7b8..6cb221a94 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -567,11 +567,10 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); - if (!lb_matrix_elem_simple(xt)) { - goto slow_form; + if (lb_matrix_elem_simple(xt)) { + // TODO(bill): SIMD version } -slow_form: { Type *elem = xt->Matrix.elem; @@ -618,6 +617,69 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type GB_ASSERT(are_types_identical(mt->Matrix.elem, base_array_type(vt))); Type *elem = mt->Matrix.elem; + LLVMTypeRef elem_type = lb_type(p->module, elem); + + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + + if (lb_matrix_elem_simple(mt)) { + unsigned row_count = cast(unsigned)mt->Matrix.row_count; gb_unused(row_count); + unsigned column_count = cast(unsigned)mt->Matrix.column_count; + auto m_columns = slice_make<LLVMValueRef>(permanent_allocator(), column_count); + auto v_rows = slice_make<LLVMValueRef>(permanent_allocator(), column_count); + + unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); + LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); + + LLVMValueRef lhs_ptr = lb_address_from_load_or_generate_local(p, lhs).value; + LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, lhs_ptr, LLVMPointerType(total_matrix_type, 0), ""); + LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); + + + for (unsigned column_index = 0; column_index < column_count; column_index++) { + LLVMValueRef mask = llvm_mask_iota(p->module, stride*column_index, row_count); + LLVMValueRef column = LLVMBuildShuffleVector(p->builder, matrix_vector, LLVMGetUndef(LLVMTypeOf(matrix_vector)), mask, ""); + m_columns[column_index] = column; + } + + for 
(unsigned row_index = 0; row_index < column_count; row_index++) { + LLVMValueRef value = lb_emit_struct_ev(p, rhs, row_index).value; + LLVMValueRef row = llvm_splat(p, value, row_count); + v_rows[row_index] = row; + } + + GB_ASSERT(column_count > 0); + + LLVMValueRef vector = nullptr; + if (is_type_float(elem)) { + for (i64 i = 0; i < column_count; i++) { + LLVMValueRef product = LLVMBuildFMul(p->builder, m_columns[i], v_rows[i], ""); + if (i == 0) { + vector = product; + } else { + vector = LLVMBuildFAdd(p->builder, vector, product, ""); + } + } + } else { + for (i64 i = 0; i < column_count; i++) { + LLVMValueRef product = LLVMBuildMul(p->builder, m_columns[i], v_rows[i], ""); + if (i == 0) { + vector = product; + } else { + vector = LLVMBuildAdd(p->builder, vector, product, ""); + } + } + } + + lbAddr res = lb_add_local_generated(p, type, true); + LLVMValueRef res_ptr = res.addr.value; + unsigned alignment = cast(unsigned)gb_max(type_align_of(type), lb_alignof(LLVMTypeOf(vector))); + LLVMSetAlignment(res_ptr, alignment); + + res_ptr = LLVMBuildPointerCast(p->builder, res_ptr, LLVMPointerType(LLVMTypeOf(vector), 0), ""); + LLVMBuildStore(p->builder, vector, res_ptr); + + return lb_addr_load(p, res); + } lbAddr res = lb_add_local_generated(p, type, true); diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index fb9264661..56637e907 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1512,4 +1512,36 @@ lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t lbValue y = lb_emit_arith(p, Token_Add, x, c, t); return y; } +} + +LLVMValueRef llvm_mask_iota(lbModule *m, unsigned start, unsigned count) { + auto iota = slice_make<LLVMValueRef>(temporary_allocator(), count); + for (unsigned i = 0; i < count; i++) { + iota[i] = lb_const_int(m, t_u32, start+i).value; + } + return LLVMConstVector(iota.data, count); +} + +LLVMValueRef llvm_mask_zero(lbModule *m, unsigned count) { + return 
LLVMConstNull(LLVMVectorType(lb_type(m, t_u32), count)); +} + +LLVMValueRef llvm_splat(lbProcedure *p, LLVMValueRef value, unsigned count) { + GB_ASSERT(count > 0); + if (LLVMIsConstant(value)) { + LLVMValueRef single = LLVMConstVector(&value, 1); + if (count == 1) { + return single; + } + LLVMValueRef mask = llvm_mask_zero(p->module, count); + return LLVMConstShuffleVector(single, LLVMGetUndef(LLVMTypeOf(single)), mask); + } + + LLVMTypeRef single_type = LLVMVectorType(LLVMTypeOf(value), 1); + LLVMValueRef single = LLVMBuildBitCast(p->builder, value, single_type, ""); + if (count == 1) { + return single; + } + LLVMValueRef mask = llvm_mask_zero(p->module, count); + return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, ""); }
\ No newline at end of file