From d0d9a3a4f4f3b4bc528c73ffcecb31d3eb4162a7 Mon Sep 17 00:00:00 2001 From: gingerBill Date: Wed, 20 Oct 2021 14:49:20 +0100 Subject: Make `lb_emit_matrix_mul` SIMD if possible --- src/llvm_backend_utility.cpp | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) (limited to 'src/llvm_backend_utility.cpp') diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index b07dc3459..6754ce798 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1577,7 +1577,7 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { GB_ASSERT_MSG(id != 0, "Unable to find %s", name); LLVMTypeRef types[1] = {}; - types[0] = elem; + types[0] = type; LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types)); LLVMValueRef values[2] = {}; @@ -1585,4 +1585,31 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { values[1] = value; LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, ""); return call; +} + +LLVMValueRef llvm_vector_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { + GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b)); + + LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a)); + + if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) { + return LLVMBuildAdd(p->builder, a, b, ""); + } + return LLVMBuildFAdd(p->builder, a, b, ""); +} + +LLVMValueRef llvm_vector_mul(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { + GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b)); + + LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a)); + + if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) { + return LLVMBuildMul(p->builder, a, b, ""); + } + return LLVMBuildFMul(p->builder, a, b, ""); +} + + +LLVMValueRef llvm_vector_dot(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { + return llvm_vector_reduce_add(p, llvm_vector_mul(p, a, b)); } \ No newline at end of file -- cgit v1.2.3