diff options
| author | gingerBill <bill@gingerbill.org> | 2021-10-28 15:01:13 +0100 |
|---|---|---|
| committer | gingerBill <bill@gingerbill.org> | 2021-10-28 15:01:13 +0100 |
| commit | 3794d2417db51ba1e50fa9f3e0d7df973328e6b0 (patch) | |
| tree | 50e9b1ce0626bcabb240a551aa6ab5568feec679 /src | |
| parent | 70793236abc278dd51ca577b35ca1757851380d3 (diff) | |
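Write an `O(log n)` shuffle-based fallback for `llvm_vector_reduce_add`
The fallback is used when the LLVM vector-reduce intrinsic cannot be looked up; this may be what LLVM does internally at any rate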
Diffstat (limited to 'src')
| -rw-r--r-- | src/common.cpp | 12 | ||||
| -rw-r--r-- | src/llvm_backend_utility.cpp | 66 |
2 files changed, 72 insertions, 6 deletions
diff --git a/src/common.cpp b/src/common.cpp index bebae6ab3..7af7026b9 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -443,7 +443,17 @@ u64 ceil_log2(u64 x) { return cast(u64)(bit_set_count(x) - 1 - y); } - +u32 prev_pow2(u32 n) { + if (n == 0) { + return 0; + } + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n - (n >> 1); +} i32 prev_pow2(i32 n) { if (n <= 0) { return 0; diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp index eccf01319..8ba23089a 100644 --- a/src/llvm_backend_utility.cpp +++ b/src/llvm_backend_utility.cpp @@ -1563,6 +1563,40 @@ LLVMValueRef llvm_vector_broadcast(lbProcedure *p, LLVMValueRef value, unsigned return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, ""); } +LLVMValueRef llvm_vector_shuffle_reduction(lbProcedure *p, LLVMValueRef value, LLVMOpcode op_code) { + LLVMValueRef v_zero32 = lb_const_int(p->module, t_u32, 0).value; + unsigned len = LLVMGetVectorSize(LLVMTypeOf(value)); + if (len == 1) { + return LLVMBuildExtractElement(p->builder, value, v_zero32, ""); + } + GB_ASSERT((len & (len-1)) == 0); + + for (unsigned i = len; i != 1; i >>= 1) { + LLVMValueRef lhs_mask = llvm_mask_iota(p->module, 0, i/2); + LLVMValueRef rhs_mask = llvm_mask_iota(p->module, i/2, i); + LLVMValueRef lhs = LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), lhs_mask, ""); + LLVMValueRef rhs = LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), rhs_mask, ""); + + value = LLVMBuildBinOp(p->builder, op_code, lhs, rhs, ""); + } + return LLVMBuildExtractElement(p->builder, value, v_zero32, ""); +} + +LLVMValueRef llvm_vector_expand_to_power_of_two(lbProcedure *p, LLVMValueRef value) { + LLVMTypeRef vector_type = LLVMTypeOf(value); + unsigned len = LLVMGetVectorSize(vector_type); + if (len == 1) { + return value; + } + if ((len & (len-1)) == 0) { + return value; + } + + unsigned expanded_len = 
cast(unsigned)next_pow2(cast(i64)len); + LLVMValueRef mask = llvm_mask_iota(p->module, 0, expanded_len); + return LLVMBuildShuffleVector(p->builder, value, LLVMConstNull(vector_type), mask, ""); +} + LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { LLVMTypeRef type = LLVMTypeOf(value); GB_ASSERT(LLVMGetTypeKind(type) == LLVMVectorTypeKind); @@ -1571,11 +1605,11 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { if (len == 0) { return LLVMConstNull(type); } - + char const *name = nullptr; i32 value_offset = 0; i32 value_count = 0; - + switch (LLVMGetTypeKind(elem)) { case LLVMHalfTypeKind: case LLVMFloatTypeKind: @@ -1593,7 +1627,7 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { GB_PANIC("invalid vector type %s", LLVMPrintTypeToString(type)); break; } - + unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name)); if (id != 0) { LLVMTypeRef types[1] = {}; @@ -1606,9 +1640,9 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, ""); return call; } - + // Manual reduce - +#if 0 LLVMValueRef sum = LLVMBuildExtractElement(p->builder, value, lb_const_int(p->module, t_u32, 0).value, ""); for (unsigned i = 0; i < len; i++) { LLVMValueRef val = LLVMBuildExtractElement(p->builder, value, lb_const_int(p->module, t_u32, i).value, ""); @@ -1619,6 +1653,28 @@ LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) { } } return sum; +#else + LLVMOpcode op_code = LLVMFAdd; + if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) { + op_code = LLVMAdd; + } + + unsigned len_pow_2 = prev_pow2(len); + if (len_pow_2 == len) { + return llvm_vector_shuffle_reduction(p, value, op_code); + } else { + LLVMValueRef lower_mask = llvm_mask_iota(p->module, 0, len_pow_2); + LLVMValueRef upper_mask = llvm_mask_iota(p->module, len_pow_2, len-len_pow_2); + LLVMValueRef lower = 
LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), lower_mask, ""); + LLVMValueRef upper = LLVMBuildShuffleVector(p->builder, value, LLVMGetUndef(LLVMTypeOf(value)), upper_mask, ""); + upper = llvm_vector_expand_to_power_of_two(p, upper); + + LLVMValueRef lower_reduced = llvm_vector_shuffle_reduction(p, lower, op_code); + LLVMValueRef upper_reduced = llvm_vector_shuffle_reduction(p, upper, op_code); + + return LLVMBuildBinOp(p->builder, op_code, lower_reduced, upper_reduced, ""); + } +#endif } LLVMValueRef llvm_vector_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) { |