diff options
Diffstat (limited to 'src/llvm_backend_expr.cpp')
| -rw-r--r-- | src/llvm_backend_expr.cpp | 273 |
1 files changed, 249 insertions, 24 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp index 6b7d90ec0..2e2d45991 100644 --- a/src/llvm_backend_expr.cpp +++ b/src/llvm_backend_expr.cpp @@ -476,6 +476,254 @@ lbValue lb_emit_arith_array(lbProcedure *p, TokenKind op, lbValue lhs, lbValue r } } +bool lb_matrix_elem_simple(Type *t) { + Type *mt = base_type(t); + GB_ASSERT(mt->kind == Type_Matrix); + + Type *elem = core_type(mt->Matrix.elem); + if (is_type_complex(elem)) { + return false; + } + + if (is_type_different_to_arch_endianness(elem)) { + return false; + } + + if (elem->kind == Type_Basic) { + switch (elem->Basic.kind) { + case Basic_f16: + case Basic_f16le: + case Basic_f16be: + // TODO(bill): determine when this is fine + return false; + } + } + + return true; +} + +LLVMValueRef llvm_matrix_column_major_load(lbProcedure *p, lbValue lhs) { + lbModule *m = p->module; + + Type *mt = base_type(lhs.type); + GB_ASSERT(mt->kind == Type_Matrix); + GB_ASSERT(lb_matrix_elem_simple(mt)); + + unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt); + + Type *elem = mt->Matrix.elem; + LLVMTypeRef elem_type = lb_type(m, elem); + + LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count); + LLVMTypeRef types[] = {vector_type}; + + char const *name = "llvm.matrix.column.major.load"; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s", name); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); + + lbValue ptr = lb_address_from_load_or_generate_local(p, lhs); + ptr = lb_emit_matrix_epi(p, ptr, 0, 0); + + LLVMValueRef values[5] = {}; + values[0] = ptr.value; + values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width + values[2] = LLVMConstNull(lb_type(m, t_llvm_bool)); + values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value; + values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value; + + return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); +} +LLVMValueRef llvm_matrix_column_major_load_from_ptr(lbProcedure *p, lbValue ptr) { + lbModule *m = p->module; + + Type *mt = base_type(type_deref(ptr.type)); + GB_ASSERT(mt->kind == Type_Matrix); + GB_ASSERT(lb_matrix_elem_simple(mt)); + + unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt); + + Type *elem = mt->Matrix.elem; + LLVMTypeRef elem_type = lb_type(m, elem); + + LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count); + LLVMTypeRef types[] = {vector_type}; + + char const *name = "llvm.matrix.column.major.load"; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s", name); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); + + LLVMValueRef values[5] = {}; + values[0] = lb_emit_matrix_epi(p, ptr, 0, 0).value; + values[1] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width + values[2] = LLVMConstNull(lb_type(m, t_llvm_bool)); + values[3] = lb_const_int(m, t_u32, mt->Matrix.row_count).value; + values[4] = lb_const_int(m, t_u32, mt->Matrix.column_count).value; + + return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); +} + +void llvm_matrix_column_major_store(lbProcedure *p, lbAddr addr, LLVMValueRef vector_value) { + lbModule *m = p->module; + + Type *mt = base_type(lb_addr_type(addr)); + GB_ASSERT(mt->kind == Type_Matrix); + GB_ASSERT(lb_matrix_elem_simple(mt)); + + unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt); + + Type *elem = mt->Matrix.elem; + LLVMTypeRef elem_type = lb_type(m, elem); + + LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count); + LLVMTypeRef types[] = {vector_type}; + + char const *name = "llvm.matrix.column.major.store"; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s", name); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); + + lbValue ptr = lb_addr_get_ptr(p, addr); + ptr = lb_emit_matrix_epi(p, ptr, 0, 0); + + GB_ASSERT(LLVMTypeOf(vector_value) == vector_type); + unsigned vector_size = LLVMGetVectorSize(vector_type); + GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size); + + LLVMValueRef values[6] = {}; + values[0] = vector_value; + values[1] = ptr.value; + values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width + values[3] = LLVMConstNull(lb_type(m, t_llvm_bool)); + values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value; + values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value; + + LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); +} + +void llvm_matrix_column_major_store_to_raw_ptr(lbProcedure *p, Type *mt, lbValue ptr, LLVMValueRef vector_value) { + lbModule *m = p->module; + + mt = base_type(mt); + GB_ASSERT(mt->kind == Type_Matrix); + GB_ASSERT(lb_matrix_elem_simple(mt)); + + unsigned total_elem_count = cast(unsigned)matrix_type_total_elems(mt); + + Type *elem = mt->Matrix.elem; + LLVMTypeRef elem_type = lb_type(m, elem); + + LLVMTypeRef vector_type = LLVMVectorType(elem_type, total_elem_count); + LLVMTypeRef types[] = {vector_type}; + + char const *name = "llvm.matrix.column.major.store"; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s", name); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); + + GB_ASSERT(LLVMTypeOf(vector_value) == vector_type); + unsigned vector_size = LLVMGetVectorSize(vector_type); + GB_ASSERT((mt->Matrix.row_count*mt->Matrix.column_count) == cast(i64)vector_size); + + LLVMValueRef values[6] = {}; + values[0] = vector_value; + values[1] = ptr.value; + values[2] = lb_const_int(m, t_u64, 8*matrix_type_stride(mt)).value; // bit width + values[3] = LLVMConstNull(lb_type(m, t_llvm_bool)); + values[4] = lb_const_int(m, t_u32, mt->Matrix.row_count).value; + values[5] = lb_const_int(m, t_u32, mt->Matrix.column_count).value; + + LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); +} + +LLVMValueRef llvm_matrix_multiply(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, i64 outer_rows, i64 inner, i64 outer_columns) { + lbModule *m = p->module; + + LLVMTypeRef a_type = LLVMTypeOf(a); + LLVMTypeRef b_type = LLVMTypeOf(b); + + GB_ASSERT(LLVMGetElementType(a_type) == LLVMGetElementType(b_type)); + + LLVMTypeRef elem_type = LLVMGetElementType(a_type); + + LLVMTypeRef res_vector_type = LLVMVectorType(elem_type, cast(unsigned)(outer_rows*outer_columns)); + LLVMTypeRef types[] = {res_vector_type, a_type, b_type}; + + char const *name = "llvm.matrix.multiply"; + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); + GB_ASSERT_MSG(id != 0, "Unable to find %s", name); + LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types)); + + LLVMValueRef values[5] = {}; + values[0] = a; + values[1] = b; + values[2] = lb_const_int(m, t_u32, outer_rows).value; + values[3] = lb_const_int(m, t_u32, inner).value; + values[4] = lb_const_int(m, t_u32, outer_columns).value; + + return LLVMBuildCall(p->builder, ip, values, gb_count_of(values), ""); +} + + +lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) { + Type *xt = base_type(lhs.type); + Type *yt = base_type(rhs.type); + + GB_ASSERT(is_type_matrix(type)); + GB_ASSERT(is_type_matrix(xt)); + GB_ASSERT(is_type_matrix(yt)); + GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); + GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); + + if (!lb_matrix_elem_simple(xt)) { + goto slow_form; + } + + if (false) { + // TODO(bill): LLVM ERROR: Do not know how to split the result of this operator! + lbAddr res = lb_add_local_generated(p, type, true); + + lbValue res_ptr = lb_addr_get_ptr(p, res); + res_ptr = lb_emit_matrix_epi(p, res_ptr, 0, 0); + + lbValue lhs_ptr = lb_address_from_load_or_generate_local(p, lhs); + lbValue rhs_ptr = lb_address_from_load_or_generate_local(p, rhs); + LLVMValueRef a = llvm_matrix_column_major_load_from_ptr(p, lhs_ptr); + LLVMValueRef b = llvm_matrix_column_major_load_from_ptr(p, rhs_ptr); + LLVMValueRef c = llvm_matrix_multiply(p, a, b, xt->Matrix.row_count, xt->Matrix.column_count, yt->Matrix.column_count); + + llvm_matrix_column_major_store_to_raw_ptr(p, type, res_ptr, c); + + return lb_addr_load(p, res); + } + +slow_form: + { + Type *elem = xt->Matrix.elem; + + lbAddr res = lb_add_local_generated(p, type, true); + + for (i64 i = 0; i < xt->Matrix.row_count; i++) { + for (i64 j = 0; j < yt->Matrix.column_count; j++) { + for (i64 k = 0; k < xt->Matrix.column_count; k++) { + lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); + + lbValue a = lb_emit_matrix_ev(p, lhs, i, k); + lbValue b = lb_emit_matrix_ev(p, rhs, k, j); + lbValue c = lb_emit_arith(p, Token_Mul, a, b, elem); + lbValue d = lb_emit_load(p, dst); + lbValue e = lb_emit_arith(p, Token_Add, d, c, elem); + lb_emit_store(p, dst, e); + + } + } + } + + return lb_addr_load(p, res); + } +} + lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue rhs, Type *type) { GB_ASSERT(is_type_matrix(lhs.type) || is_type_matrix(rhs.type)); @@ -486,30 +734,7 @@ lbValue lb_emit_arith_matrix(lbProcedure *p, TokenKind op, lbValue lhs, lbValue if (op == Token_Mul) { if (xt->kind == Type_Matrix) { if (yt->kind == Type_Matrix) { - GB_ASSERT(is_type_matrix(type)); - GB_ASSERT(xt->Matrix.column_count == yt->Matrix.row_count); - GB_ASSERT(are_types_identical(xt->Matrix.elem, yt->Matrix.elem)); - - Type *elem = xt->Matrix.elem; - - lbAddr res = lb_add_local_generated(p, type, true); - for (i64 i = 0; i < xt->Matrix.row_count; i++) { - for (i64 j = 0; j < yt->Matrix.column_count; j++) { - for (i64 k = 0; k < xt->Matrix.column_count; k++) { - lbValue dst = lb_emit_matrix_epi(p, res.addr, i, j); - - lbValue a = lb_emit_matrix_ev(p, lhs, i, k); - lbValue b = lb_emit_matrix_ev(p, rhs, k, j); - lbValue c = lb_emit_arith(p, op, a, b, elem); - lbValue d = lb_emit_load(p, dst); - lbValue e = lb_emit_arith(p, Token_Add, d, c, elem); - lb_emit_store(p, dst, e); - - } - } - } - - return lb_addr_load(p, res); + return lb_emit_matrix_mul(p, lhs, rhs, type); } } |