From 465c87bd5a38488ae7b177a10ecf93f05ec18e9d Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 20 Oct 2021 15:22:02 +0100 Subject: Make `transpose` use SIMD if possible --- src/llvm_backend_expr.cpp | 73 +++++++++++++++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 21 deletions(-) (limited to 'src/llvm_backend_expr.cpp') diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 22e66c147..c1bdceba6 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -495,21 +495,70 @@ bool lb_matrix_elem_simple(Type *t) { case Basic_f16le: case Basic_f16be: // TODO(bill): determine when this is fine - return false; + return true; } } return true; } + +LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { + Type *mt = base_type(matrix.type); + GB_ASSERT(mt->kind == Type_Matrix); + LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem); + + unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); + LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); + + LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value; + LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), ""); + LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); + return matrix_vector; +} + lbValue lb_emit_matrix_tranpose(lbProcedure *p, lbValue m, Type *type) { if (is_type_array(m.type)) { + // no-op m.type = type; return m; } Type *mt = base_type(m.type); GB_ASSERT(mt->kind == Type_Matrix); + if (lb_matrix_elem_simple(mt)) { + unsigned stride = cast(unsigned)matrix_type_stride_in_elems(mt); + unsigned row_count = cast(unsigned)mt->Matrix.row_count; + unsigned column_count = cast(unsigned)mt->Matrix.column_count; + + auto rows = slice_make<LLVMValueRef>(permanent_allocator(), row_count); + auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count); + + LLVMValueRef vector = lb_matrix_to_vector(p, m); + for (unsigned i = 0; 
i < row_count; i++) { + for (unsigned j = 0; j < column_count; j++) { + unsigned offset = stride*j + i; + mask_elems[j] = lb_const_int(p->module, t_u32, offset).value; + } + + // transpose mask + LLVMValueRef mask = LLVMConstVector(mask_elems.data, column_count); + LLVMValueRef row = LLVMBuildShuffleVector(p->builder, vector, LLVMGetUndef(LLVMTypeOf(vector)), mask, ""); + rows[i] = row; + } + + lbAddr res = lb_add_local_generated(p, type, true); + for_array(i, rows) { + LLVMValueRef row = rows[i]; + lbValue dst_row_ptr = lb_emit_matrix_epi(p, res.addr, 0, i); + LLVMValueRef ptr = dst_row_ptr.value; + ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(LLVMTypeOf(row), 0), ""); + LLVMBuildStore(p->builder, row, ptr); + } + + return lb_addr_load(p, res); + } + lbAddr res = lb_add_local_generated(p, type, true); i64 row_count = mt->Matrix.row_count; @@ -556,21 +605,6 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type) } - -LLVMValueRef lb_matrix_to_vector(lbProcedure *p, lbValue matrix) { - Type *mt = base_type(matrix.type); - GB_ASSERT(mt->kind == Type_Matrix); - LLVMTypeRef elem_type = lb_type(p->module, mt->Matrix.elem); - - unsigned total_count = cast(unsigned)matrix_type_total_elems(mt); - LLVMTypeRef total_matrix_type = LLVMVectorType(elem_type, total_count); - - LLVMValueRef ptr = lb_address_from_load_or_generate_local(p, matrix).value; - LLVMValueRef matrix_vector_ptr = LLVMBuildPointerCast(p->builder, ptr, LLVMPointerType(total_matrix_type, 0), ""); - LLVMValueRef matrix_vector = LLVMBuildLoad(p->builder, matrix_vector_ptr, ""); - return matrix_vector; -} - lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { Type *xt = base_type(lhs.type); Type *yt = base_type(rhs.type); @@ -594,12 +628,11 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) auto x_rows = slice_make<LLVMValueRef>(permanent_allocator(), outer_rows); auto y_columns = slice_make<LLVMValueRef>(permanent_allocator(), 
outer_columns); - LLVMValueRef x_vector = lb_matrix_to_vector(p, lhs); LLVMValueRef y_vector = lb_matrix_to_vector(p, rhs); + auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), inner); for (unsigned i = 0; i < outer_rows; i++) { - auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), inner); for (unsigned j = 0; j < inner; j++) { unsigned offset = x_stride*j + i; mask_elems[j] = lb_const_int(p->module, t_u32, offset).value; @@ -616,8 +649,6 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) LLVMValueRef column = LLVMBuildShuffleVector(p->builder, y_vector, LLVMGetUndef(LLVMTypeOf(y_vector)), mask, ""); y_columns[i] = column; } - - lbAddr res = lb_add_local_generated(p, type, true); for_array(i, x_rows) { @@ -760,8 +791,8 @@ lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type LLVMValueRef matrix_vector = lb_matrix_to_vector(p, rhs); + auto mask_elems = slice_make<LLVMValueRef>(permanent_allocator(), column_count); for (unsigned row_index = 0; row_index < row_count; row_index++) { - auto mask_elems = slice_make<LLVMValueRef>(temporary_allocator(), column_count); for (unsigned column_index = 0; column_index < column_count; column_index++) { unsigned offset = row_index + column_index*stride; mask_elems[column_index] = lb_const_int(p->module, t_u32, offset).value; -- cgit v1.2.3