about | summary | refs | log | tree | commit | diff
path: root/src/llvm_backend_expr.cpp
diff options
context:
space:
mode:
author: gingerBill <bill@gingerbill.org> 2021-10-25 01:03:16 +0100
committer: gingerBill <bill@gingerbill.org> 2021-10-25 01:03:16 +0100
commit: d62c701a43b255195b1d0dc2f7d80afa40d2b5fe (patch)
tree: 9b161368e1383ed2ec8bbd5a2a9218932e4a2393 /src/llvm_backend_expr.cpp
parent: 79ad6f4564e928b166d02c26836f700ea848cb87 (diff)
Improve matrix code generation for all supported platforms
Through assembly optimization
Diffstat (limited to 'src/llvm_backend_expr.cpp')
-rw-r--r-- src/llvm_backend_expr.cpp 29
1 files changed, 27 insertions, 2 deletions
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index fa2b0b084..7ae1a7315 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -489,13 +489,32 @@ bool lb_is_matrix_simdable(Type *t) {
return false;
}
+ switch (build_context.metrics.arch) {
+ case TargetArch_amd64:
+ case TargetArch_arm64:
+ // possible
+ break;
+ case TargetArch_386:
+ case TargetArch_wasm32:
+ // nope
+ return false;
+ }
+
if (elem->kind == Type_Basic) {
switch (elem->Basic.kind) {
case Basic_f16:
case Basic_f16le:
case Basic_f16be:
- // TODO(bill): determine when this is fine
- return true;
+ switch (build_context.metrics.arch) {
+ case TargetArch_amd64:
+ return false;
+ case TargetArch_arm64:
+ // TODO(bill): determine when this is fine
+ return true;
+ case TargetArch_386:
+ case TargetArch_wasm32:
+ return false;
+ }
}
}
@@ -690,6 +709,8 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type)
}
lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+ // TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+
Type *xt = base_type(lhs.type);
Type *yt = base_type(rhs.type);
@@ -775,6 +796,8 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
}
lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+ // TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+
Type *mt = base_type(lhs.type);
Type *vt = base_type(rhs.type);
@@ -843,6 +866,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
}
lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+ // TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+
Type *mt = base_type(rhs.type);
Type *vt = base_type(lhs.type);