diff options
| author | gingerBill <bill@gingerbill.org> | 2021-10-25 01:03:16 +0100 |
|---|---|---|
| committer | gingerBill <bill@gingerbill.org> | 2021-10-25 01:03:16 +0100 |
| commit | d62c701a43b255195b1d0dc2f7d80afa40d2b5fe (patch) | |
| tree | 9b161368e1383ed2ec8bbd5a2a9218932e4a2393 /src/llvm_backend_expr.cpp | |
| parent | 79ad6f4564e928b166d02c26836f700ea848cb87 (diff) | |
Improve matrix code generation for all supported platforms
Through assembly optimization
Diffstat (limited to 'src/llvm_backend_expr.cpp')
| -rw-r--r-- | src/llvm_backend_expr.cpp | 29 |
1 file changed, 27 insertions, 2 deletions
```diff
diff --git a/src/llvm_backend_expr.cpp b/src/llvm_backend_expr.cpp
index fa2b0b084..7ae1a7315 100644
--- a/src/llvm_backend_expr.cpp
+++ b/src/llvm_backend_expr.cpp
@@ -489,13 +489,32 @@ bool lb_is_matrix_simdable(Type *t) {
 		return false;
 	}
 
+	switch (build_context.metrics.arch) {
+	case TargetArch_amd64:
+	case TargetArch_arm64:
+		// possible
+		break;
+	case TargetArch_386:
+	case TargetArch_wasm32:
+		// nope
+		return false;
+	}
+
 	if (elem->kind == Type_Basic) {
 		switch (elem->Basic.kind) {
 		case Basic_f16:
 		case Basic_f16le:
 		case Basic_f16be:
-			// TODO(bill): determine when this is fine
-			return true;
+			switch (build_context.metrics.arch) {
+			case TargetArch_amd64:
+				return false;
+			case TargetArch_arm64:
+				// TODO(bill): determine when this is fine
+				return true;
+			case TargetArch_386:
+			case TargetArch_wasm32:
+				return false;
+			}
 		}
 	}
 
@@ -690,6 +709,8 @@ lbValue lb_emit_outer_product(lbProcedure *p, lbValue a, lbValue b, Type *type)
 }
 
 lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	// TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+
 	Type *xt = base_type(lhs.type);
 	Type *yt = base_type(rhs.type);
 
@@ -775,6 +796,8 @@ lbValue lb_emit_matrix_mul(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type)
 }
 
 lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	// TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+
 	Type *mt = base_type(lhs.type);
 	Type *vt = base_type(rhs.type);
 
@@ -843,6 +866,8 @@ lbValue lb_emit_matrix_mul_vector(lbProcedure *p, lbValue lhs, lbValue rhs, Type
 }
 
 lbValue lb_emit_vector_mul_matrix(lbProcedure *p, lbValue lhs, lbValue rhs, Type *type) {
+	// TODO(bill): Handle edge case for f16 types on x86(-64) platforms
+
 	Type *mt = base_type(rhs.type);
 	Type *vt = base_type(lhs.type);
 
```