aboutsummaryrefslogtreecommitdiff
path: root/src/llvm_backend_utility.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/llvm_backend_utility.cpp')
-rw-r--r--src/llvm_backend_utility.cpp296
1 files changed, 295 insertions, 1 deletions
diff --git a/src/llvm_backend_utility.cpp b/src/llvm_backend_utility.cpp
index 948180f30..e2249171c 100644
--- a/src/llvm_backend_utility.cpp
+++ b/src/llvm_backend_utility.cpp
@@ -1221,6 +1221,109 @@ lbValue lb_emit_ptr_offset(lbProcedure *p, lbValue ptr, lbValue index) {
return res;
}
+lbValue lb_emit_matrix_epi(lbProcedure *p, lbValue s, isize row, isize column) {
+ Type *t = s.type;
+ GB_ASSERT(is_type_pointer(t));
+ Type *mt = base_type(type_deref(t));
+
+ Type *ptr = base_array_type(mt);
+
+ if (column == 0) {
+ GB_ASSERT_MSG(is_type_matrix(mt) || is_type_array_like(mt), "%s", type_to_string(mt));
+
+ LLVMValueRef indices[2] = {
+ LLVMConstInt(lb_type(p->module, t_int), 0, false),
+ LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)row, false),
+ };
+
+ lbValue res = {};
+ if (lb_is_const(s)) {
+ res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices));
+ } else {
+ res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), "");
+ }
+
+ Type *ptr = base_array_type(mt);
+ res.type = alloc_type_pointer(ptr);
+ return res;
+ } else if (row == 0 && is_type_array_like(mt)) {
+ LLVMValueRef indices[2] = {
+ LLVMConstInt(lb_type(p->module, t_int), 0, false),
+ LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)column, false),
+ };
+
+ lbValue res = {};
+ if (lb_is_const(s)) {
+ res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices));
+ } else {
+ res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), "");
+ }
+
+ Type *ptr = base_array_type(mt);
+ res.type = alloc_type_pointer(ptr);
+ return res;
+ }
+
+
+ GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt));
+
+ isize offset = matrix_indices_to_offset(mt, row, column);
+
+ LLVMValueRef indices[2] = {
+ LLVMConstInt(lb_type(p->module, t_int), 0, false),
+ LLVMConstInt(lb_type(p->module, t_int), cast(unsigned)offset, false),
+ };
+
+ lbValue res = {};
+ if (lb_is_const(s)) {
+ res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices));
+ } else {
+ res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), "");
+ }
+ res.type = alloc_type_pointer(ptr);
+ return res;
+}
+
+lbValue lb_emit_matrix_ep(lbProcedure *p, lbValue s, lbValue row, lbValue column) {
+ Type *t = s.type;
+ GB_ASSERT(is_type_pointer(t));
+ Type *mt = base_type(type_deref(t));
+ GB_ASSERT_MSG(is_type_matrix(mt), "%s", type_to_string(mt));
+
+ Type *ptr = base_array_type(mt);
+
+ LLVMValueRef stride_elems = lb_const_int(p->module, t_int, matrix_type_stride_in_elems(mt)).value;
+
+ row = lb_emit_conv(p, row, t_int);
+ column = lb_emit_conv(p, column, t_int);
+
+ LLVMValueRef index = LLVMBuildAdd(p->builder, row.value, LLVMBuildMul(p->builder, column.value, stride_elems, ""), "");
+
+ LLVMValueRef indices[2] = {
+ LLVMConstInt(lb_type(p->module, t_int), 0, false),
+ index,
+ };
+
+ lbValue res = {};
+ if (lb_is_const(s)) {
+ res.value = LLVMConstGEP(s.value, indices, gb_count_of(indices));
+ } else {
+ res.value = LLVMBuildGEP(p->builder, s.value, indices, gb_count_of(indices), "");
+ }
+ res.type = alloc_type_pointer(ptr);
+ return res;
+}
+
+
+lbValue lb_emit_matrix_ev(lbProcedure *p, lbValue s, isize row, isize column) {
+ Type *st = base_type(s.type);
+ GB_ASSERT_MSG(is_type_matrix(st), "%s", type_to_string(st));
+
+ lbValue value = lb_address_from_load_or_generate_local(p, s);
+ lbValue ptr = lb_emit_matrix_epi(p, value, row, column);
+ return lb_emit_load(p, ptr);
+}
+
void lb_fill_slice(lbProcedure *p, lbAddr const &slice, lbValue base_elem, lbValue len) {
Type *t = lb_addr_type(slice);
@@ -1380,6 +1483,198 @@ lbValue lb_soa_struct_cap(lbProcedure *p, lbValue value) {
return lb_emit_struct_ev(p, value, cast(i32)n);
}
+lbValue lb_emit_mul_add(lbProcedure *p, lbValue a, lbValue b, lbValue c, Type *t) {
+ lbModule *m = p->module;
+
+ a = lb_emit_conv(p, a, t);
+ b = lb_emit_conv(p, b, t);
+ c = lb_emit_conv(p, c, t);
+
+ bool is_possible = !is_type_different_to_arch_endianness(t) && is_type_float(t);
+
+ if (is_possible) {
+ switch (build_context.metrics.arch) {
+ case TargetArch_amd64:
+ if (type_size_of(t) == 2) {
+ is_possible = false;
+ }
+ break;
+ case TargetArch_arm64:
+ // possible
+ break;
+ case TargetArch_386:
+ case TargetArch_wasm32:
+ is_possible = false;
+ break;
+ }
+ }
+
+ if (is_possible) {
+ char const *name = "llvm.fma";
+ unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+ GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+
+ LLVMTypeRef types[1] = {};
+ types[0] = lb_type(m, t);
+
+ LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+ LLVMValueRef values[3] = {};
+ values[0] = a.value;
+ values[1] = b.value;
+ values[2] = c.value;
+ LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+ return {call, t};
+ } else {
+ lbValue x = lb_emit_arith(p, Token_Mul, a, b, t);
+ lbValue y = lb_emit_arith(p, Token_Add, x, c, t);
+ return y;
+ }
+}
+
+LLVMValueRef llvm_mask_iota(lbModule *m, unsigned start, unsigned count) {
+ auto iota = slice_make<LLVMValueRef>(temporary_allocator(), count);
+ for (unsigned i = 0; i < count; i++) {
+ iota[i] = lb_const_int(m, t_u32, start+i).value;
+ }
+ return LLVMConstVector(iota.data, count);
+}
+
+LLVMValueRef llvm_mask_zero(lbModule *m, unsigned count) {
+ return LLVMConstNull(LLVMVectorType(lb_type(m, t_u32), count));
+}
+
+LLVMValueRef llvm_vector_broadcast(lbProcedure *p, LLVMValueRef value, unsigned count) {
+ GB_ASSERT(count > 0);
+ if (LLVMIsConstant(value)) {
+ LLVMValueRef single = LLVMConstVector(&value, 1);
+ if (count == 1) {
+ return single;
+ }
+ LLVMValueRef mask = llvm_mask_zero(p->module, count);
+ return LLVMConstShuffleVector(single, LLVMGetUndef(LLVMTypeOf(single)), mask);
+ }
+
+ LLVMTypeRef single_type = LLVMVectorType(LLVMTypeOf(value), 1);
+ LLVMValueRef single = LLVMBuildBitCast(p->builder, value, single_type, "");
+ if (count == 1) {
+ return single;
+ }
+ LLVMValueRef mask = llvm_mask_zero(p->module, count);
+ return LLVMBuildShuffleVector(p->builder, single, LLVMGetUndef(LLVMTypeOf(single)), mask, "");
+}
+
+LLVMValueRef llvm_vector_reduce_add(lbProcedure *p, LLVMValueRef value) {
+ LLVMTypeRef type = LLVMTypeOf(value);
+ GB_ASSERT(LLVMGetTypeKind(type) == LLVMVectorTypeKind);
+ LLVMTypeRef elem = LLVMGetElementType(type);
+
+ char const *name = nullptr;
+ i32 value_offset = 0;
+ i32 value_count = 0;
+
+ switch (LLVMGetTypeKind(elem)) {
+ case LLVMHalfTypeKind:
+ case LLVMFloatTypeKind:
+ case LLVMDoubleTypeKind:
+ name = "llvm.vector.reduce.fadd";
+ value_offset = 0;
+ value_count = 2;
+ break;
+ case LLVMIntegerTypeKind:
+ name = "llvm.vector.reduce.add";
+ value_offset = 1;
+ value_count = 1;
+ break;
+ default:
+ GB_PANIC("invalid vector type %s", LLVMPrintTypeToString(type));
+ break;
+ }
+
+ unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+ GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+
+ LLVMTypeRef types[1] = {};
+ types[0] = type;
+
+ LLVMValueRef ip = LLVMGetIntrinsicDeclaration(p->module->mod, id, types, gb_count_of(types));
+ LLVMValueRef values[2] = {};
+ values[0] = LLVMConstNull(elem);
+ values[1] = value;
+ LLVMValueRef call = LLVMBuildCall(p->builder, ip, values+value_offset, value_count, "");
+ return call;
+}
+
+LLVMValueRef llvm_vector_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) {
+ GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b));
+
+ LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a));
+
+ if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) {
+ return LLVMBuildAdd(p->builder, a, b, "");
+ }
+ return LLVMBuildFAdd(p->builder, a, b, "");
+}
+
+LLVMValueRef llvm_vector_mul(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) {
+ GB_ASSERT(LLVMTypeOf(a) == LLVMTypeOf(b));
+
+ LLVMTypeRef elem = LLVMGetElementType(LLVMTypeOf(a));
+
+ if (LLVMGetTypeKind(elem) == LLVMIntegerTypeKind) {
+ return LLVMBuildMul(p->builder, a, b, "");
+ }
+ return LLVMBuildFMul(p->builder, a, b, "");
+}
+
+
+LLVMValueRef llvm_vector_dot(lbProcedure *p, LLVMValueRef a, LLVMValueRef b) {
+ return llvm_vector_reduce_add(p, llvm_vector_mul(p, a, b));
+}
+
+LLVMValueRef llvm_vector_mul_add(lbProcedure *p, LLVMValueRef a, LLVMValueRef b, LLVMValueRef c) {
+ lbModule *m = p->module;
+
+ LLVMTypeRef t = LLVMTypeOf(a);
+ GB_ASSERT(t == LLVMTypeOf(b));
+ GB_ASSERT(t == LLVMTypeOf(c));
+ GB_ASSERT(LLVMGetTypeKind(t) == LLVMVectorTypeKind);
+
+ LLVMTypeRef elem = LLVMGetElementType(t);
+
+ bool is_possible = false;
+
+ switch (LLVMGetTypeKind(elem)) {
+ case LLVMHalfTypeKind:
+ is_possible = true;
+ break;
+ case LLVMFloatTypeKind:
+ case LLVMDoubleTypeKind:
+ is_possible = true;
+ break;
+ }
+
+ if (is_possible) {
+ char const *name = "llvm.fmuladd";
+ unsigned id = LLVMLookupIntrinsicID(name, gb_strlen(name));
+ GB_ASSERT_MSG(id != 0, "Unable to find %s", name);
+
+ LLVMTypeRef types[1] = {};
+ types[0] = t;
+
+ LLVMValueRef ip = LLVMGetIntrinsicDeclaration(m->mod, id, types, gb_count_of(types));
+ LLVMValueRef values[3] = {};
+ values[0] = a;
+ values[1] = b;
+ values[2] = c;
+ LLVMValueRef call = LLVMBuildCall(p->builder, ip, values, gb_count_of(values), "");
+ return call;
+ } else {
+ LLVMValueRef x = llvm_vector_mul(p, a, b);
+ LLVMValueRef y = llvm_vector_add(p, x, c);
+ return y;
+ }
+}
+
LLVMValueRef llvm_get_inline_asm(LLVMTypeRef func_type, String const &str, String const &clobbers, bool has_side_effects=true, bool is_align_stack=false, LLVMInlineAsmDialect dialect=LLVMInlineAsmDialectATT) {
return LLVMGetInlineAsm(func_type,
cast(char *)str.text, cast(size_t)str.len,
@@ -1391,4 +1686,3 @@ LLVMValueRef llvm_get_inline_asm(LLVMTypeRef func_type, String const &str, Strin
#endif
);
}
-