diff options
| author | DanielGavin <danielgavin5@hotmail.com> | 2025-07-25 13:22:13 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-25 13:22:13 +0200 |
| commit | 812a065324e8d8ea75fcab4bf881d8574306e1b2 (patch) | |
| tree | 11ca9d3d470c10559b34fcf6fac39feb3d5916f2 /src/server | |
| parent | 0a5488bb05dbc169f1b7761fd6872d3b4aba4422 (diff) | |
| parent | 720938e13e8ef433eecae552f471c91e2253044b (diff) | |
Merge pull request #780 from BradLewis/feat/resolve-builtin-procs
Resolve builtin max, min, clamp and abs return types correctly
Diffstat (limited to 'src/server')
| -rw-r--r-- | src/server/analysis.odin | 193 |
1 files changed, 170 insertions, 23 deletions
diff --git a/src/server/analysis.odin b/src/server/analysis.odin index 4b0134a..91fa2c1 100644 --- a/src/server/analysis.odin +++ b/src/server/analysis.odin @@ -934,6 +934,157 @@ resolve_basic_directive :: proc( return {}, false } +// Gets the return type of the proc. +// Requires the underlying call expression to handle some builtin procs +get_proc_return_types :: proc( + ast_context: ^AstContext, symbol: Symbol, call: ^ast.Call_Expr, is_mutable: bool, +) -> []^ast.Expr { + return_types := make([dynamic]^ast.Expr, context.temp_allocator) + if ret, ok := check_builtin_proc_return_type(symbol, call, is_mutable); ok { + if call, ok := ret.derived.(^ast.Call_Expr); ok { + if symbol, ok := internal_resolve_type_expression(ast_context, call.expr); ok { + return get_proc_return_types(ast_context, symbol, call, true) + } + } + append(&return_types, ret) + } else if v, ok := symbol.value.(SymbolProcedureValue); ok { + for ret in v.return_types { + if ret.type != nil { + append(&return_types, ret.type) + } else if ret.default_value != nil { + append(&return_types, ret.default_value) + } + } + } + + return return_types[:] +} + +// Attempts to resolve the type of the builtin proc by following the rules of the odin type checker +// defined in `check_builtin.cpp`. +// We don't need to worry about whether the inputs to the procs are valid which eliminates most edge cases. +// The basic rules are as follows: +// - For values not known at compile time (eg values return from procs), just return that type. +// The correct value will either be that type or a compiler error. +// - If all values are known at compile time, then we essentially compute the relevant value +// and return that type. +// There is a difference in the returned types between constants and variables. Constants will use an untyped +// value whereas variables will be typed (as either `int` or `f64`). +check_builtin_proc_return_type :: proc(symbol: Symbol, call: ^ast.Call_Expr, is_mutable: bool) -> (^ast.Expr, bool) { + convert_candidate :: proc(candidate: ^ast.Basic_Lit, is_mutable: bool) -> ^ast.Expr { + if is_mutable { + ident := ast.new(ast.Ident, candidate.pos, candidate.end) + if candidate.tok.kind == .Integer { + ident.name = "int" + } else { + ident.name = "f64" + } + return ident + } + + return candidate + } + + compare_basic_lit_value :: proc(a, b: f64, name: string) -> bool { + if name == "max" { + return a > b + } else if name == "min" { + return a < b + } + return a > b + } + + get_basic_lit_value :: proc(n: ^ast.Expr) -> (^ast.Basic_Lit, f64, bool) { + n := n + + op := "" + if u, ok := n.derived.(^ast.Unary_Expr); ok { + op = u.op.text + n = u.expr + } + + if lit, ok := n.derived.(^ast.Basic_Lit); ok { + text := lit.tok.text + if op != "" { + text = fmt.tprintf("%s%s", op, text) + } + value, ok := strconv.parse_f64(text) + if !ok { + return nil, 0, false + } + + return lit, value, true + } + + return nil, 0, false + } + + if symbol.pkg == "$builtin" { + switch symbol.name { + case "max", "min": + curr_candidate: ^ast.Basic_Lit + curr_value := 0.0 + for arg, i in call.args { + if lit, value, ok := get_basic_lit_value(arg); ok { + if i != 0 { + if compare_basic_lit_value(value, curr_value, symbol.name) { + curr_candidate = lit + curr_value = value + } + } else { + curr_candidate = lit + curr_value = value + } + } + if lit, ok := arg.derived.(^ast.Basic_Lit); ok { + } else { + return arg, true + } + } + if curr_candidate != nil { + return convert_candidate(curr_candidate, is_mutable), true + } + case "abs": + for arg in call.args { + if lit, _, ok := get_basic_lit_value(arg); ok { + return convert_candidate(lit, is_mutable), true + } + return arg, true + } + case "clamp": + if len(call.args) == 3 { + + value_lit, value_value, value_ok := get_basic_lit_value(call.args[0]) + if !value_ok { + return call.args[0], true + } + + minimum_lit, minimum_value, minimum_ok := get_basic_lit_value(call.args[1]) + if !minimum_ok { + return call.args[1], true + } + + maximum_lit, maximum_value, maximum_ok := get_basic_lit_value(call.args[2]) + if !maximum_ok { + return call.args[2], true + } + + if value_value < minimum_value { + return convert_candidate(minimum_lit, is_mutable), true + } + if value_value > maximum_value { + return convert_candidate(maximum_lit, is_mutable), true + } + + return convert_candidate(value_lit, is_mutable), true + } + } + + } + + return nil, false +} + check_node_recursion :: proc(ast_context: ^AstContext, node: ^ast.Node) -> bool { raw := cast(rawptr)node @@ -1621,13 +1772,12 @@ internal_resolve_type_identifier :: proc(ast_context: ^AstContext, node: ast.Ide } if return_symbol, ok = internal_resolve_type_expression(ast_context, v.expr); ok { - if proc_value, ok := return_symbol.value.(SymbolProcedureValue); ok { - if len(proc_value.return_types) >= 1 && proc_value.return_types[0].type != nil { - return_symbol, ok = internal_resolve_type_expression( - ast_context, - proc_value.return_types[0].type, - ) - } + return_types := get_proc_return_types(ast_context, return_symbol, v, global.mutable); + if len(return_types) > 0 { + return_symbol, ok = internal_resolve_type_expression( + ast_context, + return_types[0], + ) } // Otherwise should be a parapoly style } @@ -3189,6 +3339,7 @@ get_generic_assignment :: proc( results: ^[dynamic]^ast.Expr, calls: ^map[int]bool, flags: GetGenericAssignmentFlags, + is_mutable: bool, ) { using ast @@ -3196,11 +3347,11 @@ get_generic_assignment :: proc( #partial switch v in value.derived { case ^Or_Return_Expr: - get_generic_assignment(file, v.expr, ast_context, results, calls, flags) + get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable) case ^Or_Else_Expr: - get_generic_assignment(file, v.x, ast_context, results, calls, flags) + get_generic_assignment(file, v.x, ast_context, results, calls, flags, is_mutable) case ^Or_Branch_Expr: - get_generic_assignment(file, v.expr, ast_context, results, calls, flags) + get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable) case ^Call_Expr: old_call := ast_context.call ast_context.call = cast(^ast.Call_Expr)value @@ -3227,14 +3378,10 @@ get_generic_assignment :: proc( if symbol, ok := resolve_type_expression(ast_context, v.expr); ok { #partial switch symbol_value in symbol.value { case SymbolProcedureValue: - for ret in symbol_value.return_types { - if ret.type != nil { - calls[len(results)] = true - append(results, ret.type) - } else if ret.default_value != nil { - calls[len(results)] = true - append(results, ret.default_value) - } + return_types := get_proc_return_types(ast_context, symbol, v, is_mutable) + for ret in return_types { + calls[len(results)] = true + append(results, ret) } case SymbolAggregateValue: //In case we can't resolve the proc group, just save it anyway, so it won't cause any issues further down the line. @@ -3242,14 +3389,14 @@ get_generic_assignment :: proc( case SymbolStructValue: // Parametrized struct - get_generic_assignment(file, v.expr, ast_context, results, calls, flags) + get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable) case SymbolUnionValue: // Parametrized union - get_generic_assignment(file, v.expr, ast_context, results, calls, flags) + get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable) case: if ident, ok := v.expr.derived.(^ast.Ident); ok { - //TODO: Simple assumption that you are casting it the type. + //TODO: Simple assumption that you are casting it the type. type_ident := new_type(Ident, ident.pos, ident.end, ast_context.allocator) type_ident.name = ident.name append(results, type_ident) @@ -3347,7 +3494,7 @@ get_locals_value_decl :: proc(file: ast.File, value_decl: ast.Value_Decl, ast_co } for value in value_decl.values { - get_generic_assignment(file, value, ast_context, &results, &calls, flags) + get_generic_assignment(file, value, ast_context, &results, &calls, flags, value_decl.is_mutable) } if len(results) == 0 { @@ -3534,7 +3681,7 @@ get_locals_assign_stmt :: proc(file: ast.File, stmt: ast.Assign_Stmt, ast_contex calls := make(map[int]bool, 0, context.temp_allocator) for rhs in stmt.rhs { - get_generic_assignment(file, rhs, ast_context, &results, &calls, {}) + get_generic_assignment(file, rhs, ast_context, &results, &calls, {}, true) } if len(stmt.lhs) != len(results) { |