diff options
| author | gingerBill <bill@gingerbill.org> | 2021-10-19 11:24:26 +0100 |
|---|---|---|
| committer | gingerBill <bill@gingerbill.org> | 2021-10-19 11:24:26 +0100 |
| commit | 243e2e2b8a7566087375178a66b25b5d9ac9a356 (patch) | |
| tree | d82f6499edb8a1056c19479c59e6390f294887e7 /src/check_expr.cpp | |
| parent | 35111b39b88bb12d61e1dc67ed0161109be3f865 (diff) | |
Basic support for matrix*vector, vector*matrix operations
Diffstat (limited to 'src/check_expr.cpp')
| -rw-r--r-- | src/check_expr.cpp | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 9c12802d7..1ca5b895d 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -2686,10 +2686,11 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand x->mode = Addressing_Invalid; return; } + + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); if (is_type_matrix(x->type)) { - Type *xt = base_type(x->type); - Type *yt = base_type(y->type); GB_ASSERT(xt->kind == Type_Matrix); if (op.kind == Token_Mul) { if (yt->kind == Type_Matrix) { @@ -2714,7 +2715,11 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand // Treat arrays as column vectors x->mode = Addressing_Value; - x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1); + if (type_hint == nullptr && xt->Matrix.row_count == yt->Array.count) { + x->type = y->type; + } else { + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1); + } goto matrix_success; } } @@ -2725,8 +2730,6 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand x->type = xt; goto matrix_success; } else { - Type *xt = base_type(x->type); - Type *yt = base_type(y->type); GB_ASSERT(is_type_matrix(yt)); GB_ASSERT(!is_type_matrix(xt)); @@ -2743,7 +2746,11 @@ void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand // Treat arrays as row vectors x->mode = Addressing_Value; - x->type = alloc_type_matrix(xt->Matrix.elem, 1, xt->Matrix.column_count); + if (type_hint == nullptr && yt->Matrix.column_count == xt->Array.count) { + x->type = x->type; + } else { + x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count); + } goto matrix_success; } } @@ -2775,13 +2782,13 @@ matrix_success: matrix_error: - gbString xt = type_to_string(x->type); - gbString yt = type_to_string(y->type); + gbString xts = type_to_string(x->type); + gbString yts = type_to_string(y->type); gbString expr_str = expr_to_string(x->expr); - error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xt, yt); + error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xts, yts); gb_string_free(expr_str); - gb_string_free(yt); - gb_string_free(xt); + gb_string_free(yts); + gb_string_free(xts); x->type = t_invalid; x->mode = Addressing_Invalid; return; @@ -2994,6 +3001,7 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint } if (is_type_matrix(x->type) || is_type_matrix(y->type)) { check_binary_matrix(c, op, x, y, type_hint, use_lhs_as_type_hint); + x->expr = node; return; } |