aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDanielGavin <danielgavin5@hotmail.com>2025-07-25 13:22:13 +0200
committerGitHub <noreply@github.com>2025-07-25 13:22:13 +0200
commit812a065324e8d8ea75fcab4bf881d8574306e1b2 (patch)
tree11ca9d3d470c10559b34fcf6fac39feb3d5916f2 /src
parent0a5488bb05dbc169f1b7761fd6872d3b4aba4422 (diff)
parent720938e13e8ef433eecae552f471c91e2253044b (diff)
Merge pull request #780 from BradLewis/feat/resolve-builtin-procs
Resolve builtin max, min, clamp and abs return types correctly
Diffstat (limited to 'src')
-rw-r--r--src/server/analysis.odin193
1 files changed, 170 insertions, 23 deletions
diff --git a/src/server/analysis.odin b/src/server/analysis.odin
index 4b0134a..91fa2c1 100644
--- a/src/server/analysis.odin
+++ b/src/server/analysis.odin
@@ -934,6 +934,157 @@ resolve_basic_directive :: proc(
return {}, false
}
+// Gets the return type of the proc.
+// Requires the underlying call expression to handle some builtin procs
+get_proc_return_types :: proc(
+ ast_context: ^AstContext, symbol: Symbol, call: ^ast.Call_Expr, is_mutable: bool,
+) -> []^ast.Expr {
+ return_types := make([dynamic]^ast.Expr, context.temp_allocator)
+ if ret, ok := check_builtin_proc_return_type(symbol, call, is_mutable); ok {
+ if call, ok := ret.derived.(^ast.Call_Expr); ok {
+ if symbol, ok := internal_resolve_type_expression(ast_context, call.expr); ok {
+ return get_proc_return_types(ast_context, symbol, call, true)
+ }
+ }
+ append(&return_types, ret)
+ } else if v, ok := symbol.value.(SymbolProcedureValue); ok {
+ for ret in v.return_types {
+ if ret.type != nil {
+ append(&return_types, ret.type)
+ } else if ret.default_value != nil {
+ append(&return_types, ret.default_value)
+ }
+ }
+ }
+
+ return return_types[:]
+}
+
+// Attempts to resolve the type of the builtin proc by following the rules of the odin type checker
+// defined in `check_builtin.cpp`.
+// We don't need to worry about whether the inputs to the procs are valid which eliminates most edge cases.
+// The basic rules are as follows:
+// - For values not known at compile time (eg values return from procs), just return that type.
+// The correct value will either be that type or a compiler error.
+// - If all values are known at compile time, then we essentially compute the relevant value
+// and return that type.
+// There is a difference in the returned types between constants and variables. Constants will use an untyped
+// value whereas variables will be typed (as either `int` or `f64`).
+check_builtin_proc_return_type :: proc(symbol: Symbol, call: ^ast.Call_Expr, is_mutable: bool) -> (^ast.Expr, bool) {
+ convert_candidate :: proc(candidate: ^ast.Basic_Lit, is_mutable: bool) -> ^ast.Expr {
+ if is_mutable {
+ ident := ast.new(ast.Ident, candidate.pos, candidate.end)
+ if candidate.tok.kind == .Integer {
+ ident.name = "int"
+ } else {
+ ident.name = "f64"
+ }
+ return ident
+ }
+
+ return candidate
+ }
+
+ compare_basic_lit_value :: proc(a, b: f64, name: string) -> bool {
+ if name == "max" {
+ return a > b
+ } else if name == "min" {
+ return a < b
+ }
+ return a > b
+ }
+
+ get_basic_lit_value :: proc(n: ^ast.Expr) -> (^ast.Basic_Lit, f64, bool) {
+ n := n
+
+ op := ""
+ if u, ok := n.derived.(^ast.Unary_Expr); ok {
+ op = u.op.text
+ n = u.expr
+ }
+
+ if lit, ok := n.derived.(^ast.Basic_Lit); ok {
+ text := lit.tok.text
+ if op != "" {
+ text = fmt.tprintf("%s%s", op, text)
+ }
+ value, ok := strconv.parse_f64(text)
+ if !ok {
+ return nil, 0, false
+ }
+
+ return lit, value, true
+ }
+
+ return nil, 0, false
+ }
+
+ if symbol.pkg == "$builtin" {
+ switch symbol.name {
+ case "max", "min":
+ curr_candidate: ^ast.Basic_Lit
+ curr_value := 0.0
+ for arg, i in call.args {
+ if lit, value, ok := get_basic_lit_value(arg); ok {
+ if i != 0 {
+ if compare_basic_lit_value(value, curr_value, symbol.name) {
+ curr_candidate = lit
+ curr_value = value
+ }
+ } else {
+ curr_candidate = lit
+ curr_value = value
+ }
+ }
+ if lit, ok := arg.derived.(^ast.Basic_Lit); ok {
+ } else {
+ return arg, true
+ }
+ }
+ if curr_candidate != nil {
+ return convert_candidate(curr_candidate, is_mutable), true
+ }
+ case "abs":
+ for arg in call.args {
+ if lit, _, ok := get_basic_lit_value(arg); ok {
+ return convert_candidate(lit, is_mutable), true
+ }
+ return arg, true
+ }
+ case "clamp":
+ if len(call.args) == 3 {
+
+ value_lit, value_value, value_ok := get_basic_lit_value(call.args[0])
+ if !value_ok {
+ return call.args[0], true
+ }
+
+ minimum_lit, minimum_value, minimum_ok := get_basic_lit_value(call.args[1])
+ if !minimum_ok {
+ return call.args[1], true
+ }
+
+ maximum_lit, maximum_value, maximum_ok := get_basic_lit_value(call.args[2])
+ if !maximum_ok {
+ return call.args[2], true
+ }
+
+ if value_value < minimum_value {
+ return convert_candidate(minimum_lit, is_mutable), true
+ }
+ if value_value > maximum_value {
+ return convert_candidate(maximum_lit, is_mutable), true
+ }
+
+ return convert_candidate(value_lit, is_mutable), true
+ }
+ }
+
+ }
+
+ return nil, false
+}
+
check_node_recursion :: proc(ast_context: ^AstContext, node: ^ast.Node) -> bool {
raw := cast(rawptr)node
@@ -1621,13 +1772,12 @@ internal_resolve_type_identifier :: proc(ast_context: ^AstContext, node: ast.Ide
}
if return_symbol, ok = internal_resolve_type_expression(ast_context, v.expr); ok {
- if proc_value, ok := return_symbol.value.(SymbolProcedureValue); ok {
- if len(proc_value.return_types) >= 1 && proc_value.return_types[0].type != nil {
- return_symbol, ok = internal_resolve_type_expression(
- ast_context,
- proc_value.return_types[0].type,
- )
- }
+ return_types := get_proc_return_types(ast_context, return_symbol, v, global.mutable);
+ if len(return_types) > 0 {
+ return_symbol, ok = internal_resolve_type_expression(
+ ast_context,
+ return_types[0],
+ )
}
// Otherwise should be a parapoly style
}
@@ -3189,6 +3339,7 @@ get_generic_assignment :: proc(
results: ^[dynamic]^ast.Expr,
calls: ^map[int]bool,
flags: GetGenericAssignmentFlags,
+ is_mutable: bool,
) {
using ast
@@ -3196,11 +3347,11 @@ get_generic_assignment :: proc(
#partial switch v in value.derived {
case ^Or_Return_Expr:
- get_generic_assignment(file, v.expr, ast_context, results, calls, flags)
+ get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable)
case ^Or_Else_Expr:
- get_generic_assignment(file, v.x, ast_context, results, calls, flags)
+ get_generic_assignment(file, v.x, ast_context, results, calls, flags, is_mutable)
case ^Or_Branch_Expr:
- get_generic_assignment(file, v.expr, ast_context, results, calls, flags)
+ get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable)
case ^Call_Expr:
old_call := ast_context.call
ast_context.call = cast(^ast.Call_Expr)value
@@ -3227,14 +3378,10 @@ get_generic_assignment :: proc(
if symbol, ok := resolve_type_expression(ast_context, v.expr); ok {
#partial switch symbol_value in symbol.value {
case SymbolProcedureValue:
- for ret in symbol_value.return_types {
- if ret.type != nil {
- calls[len(results)] = true
- append(results, ret.type)
- } else if ret.default_value != nil {
- calls[len(results)] = true
- append(results, ret.default_value)
- }
+ return_types := get_proc_return_types(ast_context, symbol, v, is_mutable)
+ for ret in return_types {
+ calls[len(results)] = true
+ append(results, ret)
}
case SymbolAggregateValue:
//In case we can't resolve the proc group, just save it anyway, so it won't cause any issues further down the line.
@@ -3242,14 +3389,14 @@ get_generic_assignment :: proc(
case SymbolStructValue:
// Parametrized struct
- get_generic_assignment(file, v.expr, ast_context, results, calls, flags)
+ get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable)
case SymbolUnionValue:
// Parametrized union
- get_generic_assignment(file, v.expr, ast_context, results, calls, flags)
+ get_generic_assignment(file, v.expr, ast_context, results, calls, flags, is_mutable)
case:
if ident, ok := v.expr.derived.(^ast.Ident); ok {
- //TODO: Simple assumption that you are casting it the type.
+ //TODO: Simple assumption that you are casting it the type.
type_ident := new_type(Ident, ident.pos, ident.end, ast_context.allocator)
type_ident.name = ident.name
append(results, type_ident)
@@ -3347,7 +3494,7 @@ get_locals_value_decl :: proc(file: ast.File, value_decl: ast.Value_Decl, ast_co
}
for value in value_decl.values {
- get_generic_assignment(file, value, ast_context, &results, &calls, flags)
+ get_generic_assignment(file, value, ast_context, &results, &calls, flags, value_decl.is_mutable)
}
if len(results) == 0 {
@@ -3534,7 +3681,7 @@ get_locals_assign_stmt :: proc(file: ast.File, stmt: ast.Assign_Stmt, ast_contex
calls := make(map[int]bool, 0, context.temp_allocator)
for rhs in stmt.rhs {
- get_generic_assignment(file, rhs, ast_context, &results, &calls, {})
+ get_generic_assignment(file, rhs, ast_context, &results, &calls, {}, true)
}
if len(stmt.lhs) != len(results) {