diff options
Diffstat (limited to 'src/check_expr.cpp')
| -rw-r--r-- | src/check_expr.cpp | 385 |
1 files changed, 357 insertions, 28 deletions
diff --git a/src/check_expr.cpp b/src/check_expr.cpp index 275210c6c..d8ca190b7 100644 --- a/src/check_expr.cpp +++ b/src/check_expr.cpp @@ -657,6 +657,14 @@ i64 check_distance_between_types(CheckerContext *c, Operand *operand, Type *type return distance + 6; } } + + if (is_type_matrix(dst)) { + Type *elem = base_array_type(dst); + i64 distance = check_distance_between_types(c, operand, elem); + if (distance >= 0) { + return distance + 7; + } + } if (is_type_any(dst)) { if (!is_type_polymorphic(src)) { @@ -897,6 +905,34 @@ void check_assignment(CheckerContext *c, Operand *operand, Type *type, String co } } +bool polymorphic_assign_index(Type **gt_, i64 *dst_count, i64 source_count) { + Type *gt = *gt_; + + GB_ASSERT(gt->kind == Type_Generic); + Entity *e = scope_lookup(gt->Generic.scope, gt->Generic.name); + GB_ASSERT(e != nullptr); + if (e->kind == Entity_TypeName) { + *gt_ = nullptr; + *dst_count = source_count; + + e->kind = Entity_Constant; + e->Constant.value = exact_value_i64(source_count); + e->type = t_untyped_integer; + return true; + } else if (e->kind == Entity_Constant) { + *gt_ = nullptr; + if (e->Constant.value.kind != ExactValue_Integer) { + return false; + } + i64 count = big_int_to_i64(&e->Constant.value.value_integer); + if (count != source_count) { + return false; + } + *dst_count = source_count; + return true; + } + return false; +} bool is_polymorphic_type_assignable(CheckerContext *c, Type *poly, Type *source, bool compound, bool modify_type) { Operand o = {Addressing_Value}; @@ -951,28 +987,7 @@ bool is_polymorphic_type_assignable(CheckerContext *c, Type *poly, Type *source, case Type_Array: if (source->kind == Type_Array) { if (poly->Array.generic_count != nullptr) { - Type *gt = poly->Array.generic_count; - GB_ASSERT(gt->kind == Type_Generic); - Entity *e = scope_lookup(gt->Generic.scope, gt->Generic.name); - GB_ASSERT(e != nullptr); - if (e->kind == Entity_TypeName) { - poly->Array.generic_count = nullptr; - poly->Array.count = source->Array.count; - - e->kind = Entity_Constant; - e->Constant.value = exact_value_i64(source->Array.count); - e->type = t_untyped_integer; - } else if (e->kind == Entity_Constant) { - poly->Array.generic_count = nullptr; - if (e->Constant.value.kind != ExactValue_Integer) { - return false; - } - i64 count = big_int_to_i64(&e->Constant.value.value_integer); - if (count != source->Array.count) { - return false; - } - poly->Array.count = source->Array.count; - } else { + if (!polymorphic_assign_index(&poly->Array.generic_count, &poly->Array.count, source->Array.count)) { return false; } } @@ -1165,6 +1180,27 @@ bool is_polymorphic_type_assignable(CheckerContext *c, Type *poly, Type *source, return key || value; } return false; + + case Type_Matrix: + if (source->kind == Type_Matrix) { + if (poly->Matrix.generic_row_count != nullptr) { + poly->Matrix.stride_in_bytes = 0; + if (!polymorphic_assign_index(&poly->Matrix.generic_row_count, &poly->Matrix.row_count, source->Matrix.row_count)) { + return false; + } + } + if (poly->Matrix.generic_column_count != nullptr) { + poly->Matrix.stride_in_bytes = 0; + if (!polymorphic_assign_index(&poly->Matrix.generic_column_count, &poly->Matrix.column_count, source->Matrix.column_count)) { + return false; + } + } + if (poly->Matrix.row_count == source->Matrix.row_count && + poly->Matrix.column_count == source->Matrix.column_count) { + return is_polymorphic_type_assignable(c, poly->Matrix.elem, source->Matrix.elem, true, modify_type); + } + } + return false; } return false; } @@ -1400,8 +1436,9 @@ bool check_unary_op(CheckerContext *c, Operand *o, Token op) { } bool check_binary_op(CheckerContext *c, Operand *o, Token op) { + Type *main_type = o->type; // TODO(bill): Handle errors correctly - Type *type = base_type(core_array_type(o->type)); + Type *type = base_type(core_array_type(main_type)); Type *ct = core_type(type); switch (op.kind) { case Token_Sub: @@ -1414,10 +1451,15 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { } break; - case Token_Mul: case Token_Quo: - case Token_MulEq: case Token_QuoEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } + /*fallthrough*/ + case Token_Mul: + case Token_MulEq: case Token_AddEq: if (is_type_bit_set(type)) { return true; @@ -1458,6 +1500,10 @@ bool check_binary_op(CheckerContext *c, Operand *o, Token op) { case Token_ModMod: case Token_ModEq: case Token_ModModEq: + if (is_type_matrix(main_type)) { + error(op, "Operator '%.*s' is only allowed with matrix types", LIT(op.string)); + return false; + } if (!is_type_integer(type)) { error(op, "Operator '%.*s' is only allowed with integers", LIT(op.string)); return false; @@ -2414,6 +2460,26 @@ bool check_is_castable_to(CheckerContext *c, Operand *operand, Type *y) { if (is_type_quaternion(src) && is_type_quaternion(dst)) { return true; } + + if (is_type_matrix(src) && is_type_matrix(dst)) { + GB_ASSERT(src->kind == Type_Matrix); + GB_ASSERT(dst->kind == Type_Matrix); + if (!are_types_identical(src->Matrix.elem, dst->Matrix.elem)) { + return false; + } + + if (src->Matrix.row_count != src->Matrix.column_count) { + i64 src_count = src->Matrix.row_count*src->Matrix.column_count; + i64 dst_count = dst->Matrix.row_count*dst->Matrix.column_count; + return src_count == dst_count; + } + + if (dst->Matrix.row_count != dst->Matrix.column_count) { + return false; + } + + return true; + } // Cast between pointers @@ -2670,6 +2736,127 @@ bool can_use_other_type_as_type_hint(bool use_lhs_as_type_hint, Type *other_type return false; } +Type *check_matrix_type_hint(Type *matrix, Type *type_hint) { + Type *xt = base_type(matrix); + if (type_hint != nullptr) { + Type *th = base_type(type_hint); + if (are_types_identical(th, xt)) { + return type_hint; + } else if (xt->kind == Type_Matrix && th->kind == Type_Array) { + if (!are_types_identical(xt->Matrix.elem, th->Array.elem)) { + // ignore + } else if (xt->Matrix.row_count == 1 && xt->Matrix.column_count == th->Array.count) { + return type_hint; + } else if (xt->Matrix.column_count == 1 && xt->Matrix.row_count == th->Array.count) { + return type_hint; + } + } + } + return matrix; +} + + +void check_binary_matrix(CheckerContext *c, Token const &op, Operand *x, Operand *y, Type *type_hint, bool use_lhs_as_type_hint) { + if (!check_binary_op(c, x, op)) { + x->mode = Addressing_Invalid; + return; + } + + Type *xt = base_type(x->type); + Type *yt = base_type(y->type); + + if (is_type_matrix(x->type)) { + GB_ASSERT(xt->kind == Type_Matrix); + if (op.kind == Token_Mul) { + if (yt->kind == Type_Matrix) { + if (!are_types_identical(xt->Matrix.elem, yt->Matrix.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Matrix.row_count) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, yt->Matrix.column_count); + goto matrix_success; + } else if (yt->kind == Type_Array) { + if (!are_types_identical(xt->Matrix.elem, yt->Array.elem)) { + goto matrix_error; + } + + if (xt->Matrix.column_count != yt->Array.count) { + goto matrix_error; + } + + // Treat arrays as column vectors + x->mode = Addressing_Value; + if (type_hint == nullptr && xt->Matrix.row_count == yt->Array.count) { + x->type = y->type; + } else { + x->type = alloc_type_matrix(xt->Matrix.elem, xt->Matrix.row_count, 1); + } + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } else { + GB_ASSERT(is_type_matrix(yt)); + GB_ASSERT(!is_type_matrix(xt)); + + if (op.kind == Token_Mul) { + // NOTE(bill): no need to handle the matrix case here since it should be handled above + if (xt->kind == Type_Array) { + if (!are_types_identical(yt->Matrix.elem, xt->Array.elem)) { + goto matrix_error; + } + + if (xt->Array.count != yt->Matrix.row_count) { + goto matrix_error; + } + + // Treat arrays as row vectors + x->mode = Addressing_Value; + if (type_hint == nullptr && yt->Matrix.column_count == xt->Array.count) { + x->type = x->type; + } else { + x->type = alloc_type_matrix(yt->Matrix.elem, 1, yt->Matrix.column_count); + } + goto matrix_success; + } + } + if (!are_types_identical(xt, yt)) { + goto matrix_error; + } + x->mode = Addressing_Value; + x->type = xt; + goto matrix_success; + } + +matrix_success: + x->type = check_matrix_type_hint(x->type, type_hint); + + return; + + +matrix_error: + gbString xts = type_to_string(x->type); + gbString yts = type_to_string(y->type); + gbString expr_str = expr_to_string(x->expr); + error(op, "Mismatched types in binary matrix expression '%s' for operator '%.*s' : '%s' vs '%s'", expr_str, LIT(op.string), xts, yts); + gb_string_free(expr_str); + gb_string_free(yts); + gb_string_free(xts); + x->type = t_invalid; + x->mode = Addressing_Invalid; + return; + +} + void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint, bool use_lhs_as_type_hint=false) { GB_ASSERT(node->kind == Ast_BinaryExpr); @@ -2874,6 +3061,13 @@ void check_binary_expr(CheckerContext *c, Operand *x, Ast *node, Type *type_hint x->type = y->type; return; } + if (is_type_matrix(x->type) || is_type_matrix(y->type)) { + check_binary_matrix(c, op, x, y, type_hint, use_lhs_as_type_hint); + x->expr = node; + return; + } + + if (!are_types_identical(x->type, y->type)) { if (x->type != t_invalid && y->type != t_invalid) { @@ -3262,6 +3456,29 @@ void convert_to_typed(CheckerContext *c, Operand *operand, Type *target_type) { break; } + + case Type_Matrix: { + Type *elem = base_array_type(t); + if (check_is_assignable_to(c, operand, elem)) { + if (t->Matrix.row_count != t->Matrix.column_count) { + operand->mode = Addressing_Invalid; + begin_error_block(); + defer (end_error_block()); + + convert_untyped_error(c, operand, target_type); + error_line("\tNote: Only a square matrix types can be initialized with a scalar value\n"); + return; + } else { + operand->mode = Addressing_Value; + } + } else { + operand->mode = Addressing_Invalid; + convert_untyped_error(c, operand, target_type); + return; + } + break; + } + case Type_Union: if (!is_operand_nil(*operand) && !is_operand_undef(*operand)) { @@ -6219,6 +6436,16 @@ bool check_set_index_data(Operand *o, Type *t, bool indirection, i64 *max_count, } o->type = t->EnumeratedArray.elem; return true; + + case Type_Matrix: + *max_count = t->Matrix.column_count; + if (indirection) { + o->mode = Addressing_Variable; + } else if (o->mode != Addressing_Variable) { + o->mode = Addressing_Value; + } + o->type = alloc_type_array(t->Matrix.elem, t->Matrix.row_count); + return true; case Type_Slice: o->type = t->Slice.elem; @@ -6517,6 +6744,72 @@ void check_promote_optional_ok(CheckerContext *c, Operand *x, Type **val_type_, } +void check_matrix_index_expr(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) { + ast_node(ie, MatrixIndexExpr, node); + + check_expr(c, o, ie->expr); + node->viral_state_flags |= ie->expr->viral_state_flags; + if (o->mode == Addressing_Invalid) { + o->expr = node; + return; + } + + Type *t = base_type(type_deref(o->type)); + bool is_ptr = is_type_pointer(o->type); + bool is_const = o->mode == Addressing_Constant; + + if (t->kind != Type_Matrix) { + gbString str = expr_to_string(o->expr); + gbString type_str = type_to_string(o->type); + defer (gb_string_free(str)); + defer (gb_string_free(type_str)); + if (is_const) { + error(o->expr, "Cannot use matrix indexing on constant '%s' of type '%s'", str, type_str); + } else { + error(o->expr, "Cannot use matrix indexing on '%s' of type '%s'", str, type_str); + } + o->mode = Addressing_Invalid; + o->expr = node; + return; + } + o->type = t->Matrix.elem; + if (is_ptr) { + o->mode = Addressing_Variable; + } else if (o->mode != Addressing_Variable) { + o->mode = Addressing_Value; + } + + if (ie->row_index == nullptr) { + gbString str = expr_to_string(o->expr); + error(o->expr, "Missing row index for '%s'", str); + gb_string_free(str); + o->mode = Addressing_Invalid; + o->expr = node; + return; + } + if (ie->column_index == nullptr) { + gbString str = expr_to_string(o->expr); + error(o->expr, "Missing column index for '%s'", str); + gb_string_free(str); + o->mode = Addressing_Invalid; + o->expr = node; + return; + } + + i64 row_count = t->Matrix.row_count; + i64 column_count = t->Matrix.column_count; + + i64 row_index = 0; + i64 column_index = 0; + bool row_ok = check_index_value(c, t, false, ie->row_index, row_count, &row_index, nullptr); + bool column_ok = check_index_value(c, t, false, ie->column_index, column_count, &column_index, nullptr); + + + gb_unused(row_ok); + gb_unused(column_ok); +} + + ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type *type_hint) { u32 prev_state_flags = c->state_flags; defer (c->state_flags = prev_state_flags); @@ -7150,6 +7443,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type case Type_Array: case Type_DynamicArray: case Type_SimdVector: + case Type_Matrix: { Type *elem_type = nullptr; String context_name = {}; @@ -7176,6 +7470,10 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type elem_type = t->SimdVector.elem; context_name = str_lit("simd vector literal"); max_type_count = t->SimdVector.count; + } else if (t->kind == Type_Matrix) { + elem_type = t->Matrix.elem; + context_name = str_lit("matrix literal"); + max_type_count = t->Matrix.row_count*t->Matrix.column_count; } else { GB_PANIC("unreachable"); } @@ -8214,6 +8512,8 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type // Okay } else if (is_type_relative_slice(t)) { // Okay + } else if (is_type_matrix(t)) { + // Okay } else { valid = false; } @@ -8278,10 +8578,14 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type } } } + + if (type_hint != nullptr && is_type_matrix(t)) { + // TODO(bill): allow matrix columns to be assignable to other types which are the same internally + // if a type hint exists + } + case_end; - - case_ast_node(se, SliceExpr, node); check_expr(c, o, se->expr); node->viral_state_flags |= se->expr->viral_state_flags; @@ -8454,7 +8758,12 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type } case_end; - + + case_ast_node(mie, MatrixIndexExpr, node); + check_matrix_index_expr(c, o, node, type_hint); + o->expr = node; + return Expr_Expr; + case_end; case_ast_node(ce, CallExpr, node); return check_call_expr(c, o, node, ce->proc, ce->args, ce->inlining, type_hint); @@ -8561,6 +8870,7 @@ ExprKind check_expr_base_internal(CheckerContext *c, Operand *o, Ast *node, Type case Ast_EnumType: case Ast_MapType: case Ast_BitSetType: + case Ast_MatrixType: o->mode = Addressing_Type; o->type = check_type(c, node); break; @@ -8964,6 +9274,15 @@ gbString write_expr_to_string(gbString str, Ast *node, bool shorthand) { str = gb_string_append_rune(str, ']'); case_end; + case_ast_node(mie, MatrixIndexExpr, node); + str = write_expr_to_string(str, mie->expr, shorthand); + str = gb_string_append_rune(str, '['); + str = write_expr_to_string(str, mie->row_index, shorthand); + str = gb_string_appendc(str, ", "); + str = write_expr_to_string(str, mie->column_index, shorthand); + str = gb_string_append_rune(str, ']'); + case_end; + case_ast_node(e, Ellipsis, node); str = gb_string_appendc(str, ".."); str = write_expr_to_string(str, e->expr, shorthand); @@ -9035,6 +9354,16 @@ gbString write_expr_to_string(gbString str, Ast *node, bool shorthand) { str = gb_string_append_rune(str, ']'); str = write_expr_to_string(str, mt->value, shorthand); case_end; + + case_ast_node(mt, MatrixType, node); + str = gb_string_appendc(str, "matrix["); + str = write_expr_to_string(str, mt->row_count, shorthand); + str = gb_string_appendc(str, ", "); + str = write_expr_to_string(str, mt->column_count, shorthand); + str = gb_string_append_rune(str, ']'); + str = write_expr_to_string(str, mt->elem, shorthand); + case_end; + case_ast_node(f, Field, node); if (f->flags&FieldFlag_using) { |