diff options
| author | gingerBill <bill@gingerbill.org> | 2021-10-18 18:16:52 +0100 |
|---|---|---|
| committer | gingerBill <bill@gingerbill.org> | 2021-10-18 18:16:52 +0100 |
| commit | ba331024af2f5074125442e91dda6c8e63324c8f (patch) | |
| tree | 86545465efba2d2fb0221c20e83e881efea5c96e /src/check_expr.cpp | |
| parent | 4c655865e5d9af83a98c137609b01972f4e51beb (diff) | |
Very basic matrix support in backend
Diffstat (limited to 'src/check_expr.cpp')
| -rw-r--r-- | src/check_expr.cpp | 153 |
1 files changed, 150 insertions, 3 deletions
diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 85f2eeb23..9c12802d7 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -1400,8 +1400,9 @@ bool check_unary_op(CheckerContext *c, Operand *o, Token op) { } bool check_binary_op(CheckerContext *c, Operand *o, Token op) { + Type *main_type = o->type; // TODO(bill): Handle errors correctly - Type *type = base_type(core_array_type(o->type)); + Type *type = base_type(core_array_type(main_type)); Type *ct = core_type(type); switch (op.kind) { case Token_Sub: @@ -1414,10 +1415,15 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { } break; - case Token_Mul: case Token_Quo: - case Token_MulEq: case Token_QuoEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } + /*fallthrough*/ + case Token_Mul: + case Token_MulEq: case Token_AddEq: if (is_type_bit_set(type)) { return true; @@ -1458,6 +1464,10 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { case Token_ModMod: case Token_ModEq: case Token_ModModEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } if (!is_type_integer(type)) { error(op, "Operator '%.*s' is only allowed with integers", LIT(op.string)); return false; @@ -2671,6 +2681,114 @@ bool can_use_other_type_as_type_hint(bool use_lhs_as_type_hint, Type *other_type } +void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand *y, Type *type_hint, bool use_lhs_as_type_hint) { + if (!check_binary_op(c, x, op)) { + x->mode = Addressing_Invalid; + return; + } + + if (is_type_matrix(x->type)) { + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); + GB_ASSERT(xt->kind == Type_Matrix); + if (op.kind == Token_Mul) { + if (yt->kind == Type_Matrix) { + if (!are_types_identical(xt->Matrix.elem, yt->Matrix.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Matrix.row_count) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count); + goto matrix_success; + } else if (yt->kind == Type_Array) { + if (!are_types_identical(xt->Matrix.elem, yt->Array.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Array.count) { + goto matrix_error; + } + + // Treat arrays as column vectors + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1); + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } else { + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); + GB_ASSERT(is_type_matrix(yt)); + GB_ASSERT(!is_type_matrix(xt)); + + if (op.kind == Token_Mul) { + // NOTE(bill): no need to handle the matrix case here since it should be handled above + if (xt->kind == Type_Array) { + if (!are_types_identical(yt->Matrix.elem, xt->Array.elem)) { + goto matrix_error; + } + + if (xt->Array.count != yt->Matrix.row_count) { + goto matrix_error; + } + + // Treat arrays as row vectors + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, 1, xt->Matrix.column_count); + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } + +matrix_success: + if (type_hint != nullptr) { + Type *th = base_type(type_hint); + if (are_types_identical(th, x->type)) { + x->type = type_hint; + } else if (x->type->kind == Type_Matrix && th->kind == Type_Array) { + Type *xt = x->type; + if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) { + // ignore + } else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) { + x->type = type_hint; + } else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) { + x->type = type_hint; + } + } + } + return; + + +matrix_error: + gbString xt = type_to_string(x->type); + gbString yt = type_to_string(y->type); + gbString expr_str = expr_to_string(x->expr); + error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xt, yt); + gb_string_free(expr_str); + gb_string_free(yt); + gb_string_free(xt); + x->type = t_invalid; + x->mode = Addressing_Invalid; + return; + +} + + void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint, bool use_lhs_as_type_hint=false) { GB_ASSERT(node->kind == Ast_BinaryExpr); Operand y_ = {}, *y = &y_; @@ -2874,6 +2992,12 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint x->type = y->type; return; } + if (is_type_matrix(x->type) || is_type_matrix(y->type)) { + check_binary_matrix(c, op, x, y, type_hint, use_lhs_as_type_hint); + return; + } + + if (!are_types_identical(x->type, y->type)) { if (x->type != t_invalid && y->type != t_invalid) { @@ -3258,6 +3382,29 @@ void convert_to_typed(CheckerContext *c, Operand *operand, Type *target_type) { break; } + + case Type_Matrix: { + Type *elem = base_array_type(t); + if (check_is_assignable_to(c, operand, elem)) { + if (t->Matrix.row_count != t->Matrix.column_count) { + operand->mode = Addressing_Invalid; + begin_error_block(); + defer (end_error_block()); + + convert_untyped_error(c, operand, target_type); + error_line("\tNote: Only a square matrix types can be initialized with a scalar value\n"); + return; + } else { + operand->mode = Addressing_Value; + } + } else { + operand->mode = Addressing_Invalid; + convert_untyped_error(c, operand, target_type); + return; + } + break; + } + case Type_Union: if (!is_operand_nil(*operand) && !is_operand_undef(*operand)) { |