package server import "core:odin/ast" import "core:odin/tokenizer" import "core:reflect" import "core:strings" import "src:common" resolve_poly :: proc( ast_context: ^AstContext, call_node: ^ast.Expr, call_symbol: Symbol, poly_node: ^ast.Expr, poly_map: ^map[string]^ast.Expr, ) -> bool { if poly_node == nil || call_node == nil { return false } specialization: ^ast.Expr type: ^ast.Expr poly_node := poly_node poly_node, _, _ = unwrap_pointer_expr(poly_node) #partial switch v in poly_node.derived { case ^ast.Typeid_Type: specialization = v.specialization case ^ast.Poly_Type: specialization = v.specialization type = v.type case: specialization = poly_node } if specialization == nil { if type != nil { if ident, ok := unwrap_ident(type); ok { if untyped_value, ok := call_symbol.value.(SymbolUntypedValue); ok { save_poly_map(ident, symbol_to_expr(call_symbol, call_node.pos.file), poly_map) } else { save_poly_map( ident, make_ident_ast(ast_context, call_node.pos, call_node.end, call_symbol.name), poly_map, ) } } } return true } else if type != nil { if ident, ok := unwrap_ident(type); ok { call_node_id := reflect.union_variant_typeid(call_node.derived) specialization_id := reflect.union_variant_typeid(specialization.derived) if ast_context.position_hint == .TypeDefinition && call_node_id == specialization_id { // TODO: Fix this so it doesn't need to be aware that we're in a type definition // if the specialization type matches the type of the parameter passed to the proc // we store that rather than the specialization so we can follow it correctly // for things like `textDocument/typeDefinition` save_poly_map( ident, make_ident_ast(ast_context, call_node.pos, call_node.end, call_symbol.name), poly_map, ) } else { save_poly_map(ident, specialization, poly_map) } } } #partial switch p in specialization.derived { case ^ast.Matrix_Type: if call_matrix, ok := call_node.derived.(^ast.Matrix_Type); ok { found := false if poly_type, ok := p.row_count.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_matrix.row_count, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_matrix.row_count, call_symbol, p.row_count, poly_map) } found |= true } if poly_type, ok := p.column_count.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_matrix.column_count, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_matrix.column_count, call_symbol, p.column_count, poly_map) } found |= true } if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_matrix.elem, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_matrix.elem, call_symbol, p.elem, poly_map) } found |= true } return found } case ^ast.Call_Expr: if call_struct, ok := call_node.derived.(^ast.Struct_Type); ok { arg_index := 0 struct_value := call_symbol.value.(SymbolStructValue) found := false for arg in p.args { if poly_type, ok := arg.derived.(^ast.Poly_Type); ok { if poly_type.type == nil || len(struct_value.args) <= arg_index { return false } save_poly_map(poly_type.type, struct_value.args[arg_index], poly_map) arg_index += 1 found |= true } } return found } case ^ast.Dynamic_Array_Type: if call_array, ok := call_node.derived.(^ast.Dynamic_Array_Type); ok { if dynamic_array_is_soa(p^) != dynamic_array_is_soa(call_array^) { return false } //It's not enough for them to both arrays, they also have to share soa attributes if p.tag != nil && call_array.tag != nil { a, ok1 := p.tag.derived.(^ast.Basic_Directive) b, ok2 := call_array.tag.derived.(^ast.Basic_Directive) if ok1 && ok2 && (a.name == "soa" || b.name == "soa") && a.name != b.name { return false } } if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_array.elem, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_array.elem, call_symbol, p.elem, poly_map) } return true } } case ^ast.Array_Type: if call_array, ok := call_node.derived.(^ast.Array_Type); ok { found := false if array_is_soa(p^) != array_is_soa(call_array^) { return false } //It's not enough for them to both arrays, they also have to share soa attributes if p.tag != nil && call_array.tag != nil { a, ok1 := p.tag.derived.(^ast.Basic_Directive) b, ok2 := call_array.tag.derived.(^ast.Basic_Directive) if ok1 && ok2 && (a.name == "soa" || b.name == "soa") && a.name != b.name { return false } } if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_array.elem, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_array.elem, call_symbol, p.elem, poly_map) } found |= true } if p.len != nil { if poly_type, ok := p.len.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_array.len, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_array.len, call_symbol, p.len, poly_map) } found |= true } } return found } case ^ast.Ellipsis: if call_array, ok := call_node.derived.(^ast.Array_Type); ok { found := false if array_is_soa(call_array^) { return false } if poly_type, ok := p.expr.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_array.elem, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_array.elem, call_symbol, p.expr, poly_map) } found |= true } return found } case ^ast.Map_Type: if call_map, ok := call_node.derived.(^ast.Map_Type); ok { found := false if poly_type, ok := p.key.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_map.key, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_map.key, call_symbol, p.key, poly_map) } found |= true } if poly_type, ok := p.value.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_map.value, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_map.value, call_symbol, p.value, poly_map) } found |= true } return found } case ^ast.Multi_Pointer_Type: if call_pointer, ok := call_node.derived.(^ast.Multi_Pointer_Type); ok { if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_pointer.elem, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_pointer.elem, call_symbol, p.elem, poly_map) } return true } } case ^ast.Pointer_Type: if call_pointer, ok := call_node.derived.(^ast.Pointer_Type); ok { if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, call_pointer.elem, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, call_pointer.elem, call_symbol, p.elem, poly_map) } return true } } case ^ast.Comp_Lit: if comp_lit, ok := call_node.derived.(^ast.Comp_Lit); ok { if poly_type, ok := p.type.derived.(^ast.Poly_Type); ok { if ident, ok := unwrap_ident(poly_type.type); ok { save_poly_map(ident, comp_lit.type, poly_map) } if poly_type.specialization != nil { return resolve_poly(ast_context, comp_lit.type, call_symbol, p.type, poly_map) } return true } } case ^ast.Struct_Type, ^ast.Proc_Type: case ^ast.Ident: return true case: return false } return false } is_generic_type_recursive :: proc(expr: ^ast.Expr, name: string) -> bool { Data :: struct { name: string, exists: bool, } visit_function :: proc(visitor: ^ast.Visitor, node: ^ast.Node) -> ^ast.Visitor { if node == nil { return nil } data := cast(^Data)visitor.data if ident, ok := node.derived.(^ast.Ident); ok { if ident.name == data.name { data.exists = true return nil } } return visitor } data := Data { name = name, } visitor := ast.Visitor { data = &data, visit = visit_function, } ast.walk(&visitor, expr) return data.exists } save_poly_map :: proc(ident: ^ast.Ident, expr: ^ast.Expr, poly_map: ^map[string]^ast.Expr) { if ident == nil || expr == nil { return } poly_map[ident.name] = expr } get_poly_map :: proc(node: ^ast.Node, poly_map: ^map[string]^ast.Expr) -> (^ast.Expr, bool) { if node == nil { return {}, false } if ident, ok := node.derived.(^ast.Ident); ok { if v, ok := poly_map[ident.name]; ok && !is_generic_type_recursive(v, ident.name) { return v, ok } } if poly, ok := node.derived.(^ast.Poly_Type); ok && poly.type != nil { if v, ok := poly_map[poly.type.name]; ok && !is_generic_type_recursive(v, poly.type.name) { return v, ok } } return nil, false } find_and_replace_poly_type :: proc(expr: ^ast.Expr, poly_map: ^map[string]^ast.Expr) { visit_function :: proc(visitor: ^ast.Visitor, node: ^ast.Node) -> ^ast.Visitor { if node == nil { return nil } poly_map := cast(^map[string]^ast.Expr)visitor.data #partial switch v in node.derived { case ^ast.Comp_Lit: if expr, ok := get_poly_map(v.type, poly_map); ok { v.type = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Matrix_Type: if expr, ok := get_poly_map(v.elem, poly_map); ok { v.elem = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } if expr, ok := get_poly_map(v.column_count, poly_map); ok { v.column_count = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } if expr, ok := get_poly_map(v.row_count, poly_map); ok { v.row_count = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Dynamic_Array_Type: if expr, ok := get_poly_map(v.elem, poly_map); ok { v.elem = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Array_Type: if expr, ok := get_poly_map(v.elem, poly_map); ok { v.elem = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } if expr, ok := get_poly_map(v.len, poly_map); ok { v.len = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Multi_Pointer_Type: if expr, ok := get_poly_map(v.elem, poly_map); ok { v.elem = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Pointer_Type: if expr, ok := get_poly_map(v.elem, poly_map); ok { v.elem = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Proc_Type: if v.params != nil { for param in v.params.list { if expr, ok := get_poly_map(param.type, poly_map); ok { param.type = expr param.pos.file = expr.pos.file param.end.file = expr.end.file } } } if v.results != nil { for result in v.results.list { if expr, ok := get_poly_map(result.type, poly_map); ok { result.type = expr result.pos.file = expr.pos.file result.end.file = expr.end.file } } } case ^ast.Ellipsis: if expr, ok := get_poly_map(v.expr, poly_map); ok { v.expr = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Map_Type: if expr, ok := get_poly_map(v.key, poly_map); ok { v.key = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } if expr, ok := get_poly_map(v.value, poly_map); ok { v.value = expr v.pos.file = expr.pos.file v.end.file = expr.end.file } case ^ast.Call_Expr: for &arg in v.args { if expr, ok := get_poly_map(arg, poly_map); ok { arg = expr } } } return visitor } visitor := ast.Visitor { data = poly_map, visit = visit_function, } ast.walk(&visitor, expr) } resolve_generic_function :: proc { resolve_generic_function_ast, resolve_generic_function_symbol, } resolve_generic_function_ast :: proc( ast_context: ^AstContext, proc_lit: ast.Proc_Lit, proc_symbol: Symbol, ) -> ( Symbol, bool, ) { if ast_context.call == nil { return Symbol{}, false } params: []^ast.Field if proc_lit.type.params != nil { params = proc_lit.type.params.list } results: []^ast.Field if proc_lit.type.results != nil { results = proc_lit.type.results.list } range := common.get_token_range(proc_lit, ast_context.file.src) uri := common.create_uri(proc_lit.pos.file, ast_context.allocator).uri return resolve_generic_function_symbol(ast_context, params, results, proc_lit.inlining, proc_symbol) } resolve_generic_function_symbol :: proc( ast_context: ^AstContext, params: []^ast.Field, results: []^ast.Field, inlining: ast.Proc_Inlining, proc_symbol: Symbol, ) -> ( Symbol, bool, ) { if ast_context.call == nil { return {}, false } call_expr := ast_context.call poly_map := make(map[string]^ast.Expr, 0, context.temp_allocator) i := 0 count_required_params := 0 for param in params { if param.default_value == nil { count_required_params += 1 } for name in param.names { if len(call_expr.args) <= i { break } if param.type == nil { continue } reset_ast_context(ast_context) ast_context.current_package = ast_context.document_package if symbol, ok := resolve_type_expression(ast_context, call_expr.args[i]); ok { if ident, ok := call_expr.args[i].derived.(^ast.Ident); ok && symbol.name == "" { symbol.name = ident.name } file := strings.trim_prefix(symbol.uri, "file://") if file == "" { file = call_expr.args[i].pos.file } symbol_expr := symbol_to_expr(symbol, file, context.temp_allocator) if symbol_expr == nil { return {}, false } //If we have a function call, we should instead look at the return value: bar(foo(123)) if symbol_value, ok := symbol.value.(SymbolProcedureValue); ok && len(symbol_value.return_types) > 0 { if _, ok := call_expr.args[i].derived.(^ast.Call_Expr); ok { if symbol_value.return_types[0].type != nil { if symbol, ok = resolve_type_expression(ast_context, symbol_value.return_types[0].type); ok { symbol_expr = symbol_to_expr( symbol, call_expr.args[i].pos.file, context.temp_allocator, ) if symbol_expr == nil { return {}, false } } } } } // We set the offset so we can find it as a local if it's based on the type of a local var symbol_expr.pos.offset = call_expr.pos.offset symbol_expr.end.offset = call_expr.end.offset symbol_expr = clone_expr(symbol_expr, ast_context.allocator, nil) param_type := clone_expr(param.type, ast_context.allocator, nil) if resolve_poly(ast_context, symbol_expr, symbol, param_type, &poly_map) { if poly, ok := name.derived.(^ast.Poly_Type); ok { poly_map[poly.type.name] = clone_expr(call_expr.args[i], ast_context.allocator, nil) } } } i += 1 } } for k, v in poly_map { find_and_replace_poly_type(v, &poly_map) } if count_required_params > len(call_expr.args) || count_required_params == 0 || len(call_expr.args) == 0 { return {}, false } function_name := "" if ident, ok := call_expr.expr.derived.(^ast.Ident); ok { function_name = ident.name } else if selector, ok := call_expr.expr.derived.(^ast.Selector_Expr); ok { function_name = selector.field.name } else { return {}, false } return_types := make([dynamic]^ast.Field, ast_context.allocator) argument_types := make([dynamic]^ast.Field, ast_context.allocator) for result in results { if result.type == nil { continue } field := cast(^ast.Field)clone_node(result, ast_context.allocator, nil) if ident, ok := unwrap_ident(field.type); ok { if expr, ok := poly_map[ident.name]; ok { field.type = expr } } find_and_replace_poly_type(field.type, &poly_map) append(&return_types, field) } for param in params { field := cast(^ast.Field)clone_node(param, ast_context.allocator, nil) if field.type != nil { if poly_type, ok := field.type.derived.(^ast.Poly_Type); ok { if expr, ok := poly_map[poly_type.type.name]; ok { field.type = expr } } else { if ident, ok := unwrap_ident(field.type); ok { if expr, ok := poly_map[ident.name]; ok { field.type = expr } } find_and_replace_poly_type(field.type, &poly_map) } } if len(param.names) > 0 { if poly_type, ok := param.names[0].derived.(^ast.Poly_Type); ok && param.type != nil { if m, ok := poly_map[poly_type.type.name]; ok { field.type = m } } } append(&argument_types, field) } symbol := proc_symbol symbol.value = SymbolProcedureValue { return_types = return_types[:], arg_types = argument_types[:], orig_arg_types = params[:], orig_return_types = results[:], inlining = inlining, } return symbol, true } is_procedure_generic :: proc(proc_type: ^ast.Proc_Type) -> bool { if proc_type.generic { return true } for param in proc_type.params.list { if param.type == nil { continue } if expr_contains_poly(param.type) { return true } } return false } resolve_poly_struct :: proc(ast_context: ^AstContext, b: ^SymbolStructValueBuilder, poly_params: ^ast.Field_List) { if ast_context.call == nil { return } i := 0 poly_map := make(map[string]^ast.Expr, 0, context.temp_allocator) clear(&b.poly_names) for param in poly_params.list { for name in param.names { append(&b.poly_names, node_to_string(name)) if len(ast_context.call.args) <= i { break } if param.type == nil { continue } if ident, ok := param.type.derived.(^ast.Ident); ok { poly_map[ident.name] = ast_context.call.args[i] b.poly_names[i] = node_to_string(ast_context.call.args[i]) } else if poly, ok := param.type.derived.(^ast.Typeid_Type); ok { if ident, ok := name.derived.(^ast.Ident); ok { poly_map[ident.name] = ast_context.call.args[i] b.poly_names[i] = node_to_string(ast_context.call.args[i]) } else if poly, ok := name.derived.(^ast.Poly_Type); ok { if poly.type != nil { b.poly_names[i] = node_to_string(ast_context.call.args[i]) poly_map[poly.type.name] = ast_context.call.args[i] } } } append(&b.args, ast_context.call.args[i]) i += 1 } } Visit_Data :: struct { poly_map: map[string]^ast.Expr, symbol_value_builder: ^SymbolStructValueBuilder, parent: ^ast.Node, parent_proc: ^ast.Proc_Type, i: int, poly_index: int, } visit :: proc(visitor: ^ast.Visitor, node: ^ast.Node) -> ^ast.Visitor { if node == nil || visitor == nil { return nil } data := cast(^Visit_Data)visitor.data if ident, ok := node.derived.(^ast.Ident); ok { if expr, ok := data.poly_map[ident.name]; ok { if data.parent_proc != nil { // If the field is a parapoly procedure, we check to see if any of the params or return types // need to be updated if data.parent_proc.params != nil { for ¶m in data.parent_proc.params.list { type := param.type if type == nil { type = param.default_value } if type != nil { if param_ident, ok := type.derived.(^ast.Ident); ok { if param_ident.name == ident.name { param.type = expr } } } } } if data.parent_proc.results != nil { for &return_value in data.parent_proc.results.list { if return_ident, ok := return_value.type.derived.(^ast.Ident); ok { if return_ident.name == ident.name { return_value.type = expr } } } } } if data.parent != nil { #partial switch &v in data.parent.derived { case ^ast.Array_Type: v.elem = expr case ^ast.Dynamic_Array_Type: v.elem = expr case ^ast.Pointer_Type: v.elem = expr } } else if data.parent_proc == nil { data.symbol_value_builder.types[data.i] = expr data.poly_index += 1 } } } #partial switch v in node.derived { case ^ast.Array_Type, ^ast.Dynamic_Array_Type, ^ast.Selector_Expr, ^ast.Pointer_Type: data.parent = node case ^ast.Proc_Type: data.parent_proc = v } return visitor } for type, i in b.types { data := Visit_Data { poly_map = poly_map, symbol_value_builder = b, i = i, } visitor := ast.Visitor { data = &data, visit = visit, } ast.walk(&visitor, type) } } resolve_poly_union :: proc(ast_context: ^AstContext, poly_params: ^ast.Field_List, symbol: ^Symbol) { if ast_context.call == nil { return } symbol_value := &symbol.value.(SymbolUnionValue) if symbol_value == nil { return } i := 0 poly_map := make(map[string]^ast.Expr, 0, context.temp_allocator) poly_names := make([dynamic]string, 0, context.temp_allocator) for param in poly_params.list { for name in param.names { append(&poly_names, node_to_string(name)) if len(ast_context.call.args) <= i { break } if param.type == nil { continue } if poly, ok := param.type.derived.(^ast.Typeid_Type); ok { if ident, ok := name.derived.(^ast.Ident); ok { poly_map[ident.name] = ast_context.call.args[i] poly_names[i] = node_to_string(ast_context.call.args[i]) } else if poly, ok := name.derived.(^ast.Poly_Type); ok { if poly.type != nil { poly_map[poly.type.name] = ast_context.call.args[i] poly_names[i] = node_to_string(ast_context.call.args[i]) } } } i += 1 } } for type, i in symbol_value.types { if ident, ok := type.derived.(^ast.Ident); ok { if expr, ok := poly_map[ident.name]; ok { symbol_value.types[i] = expr } } else if call_expr, ok := type.derived.(^ast.Call_Expr); ok { if call_expr.args == nil { continue } for arg, i in call_expr.args { if ident, ok := arg.derived.(^ast.Ident); ok { if expr, ok := poly_map[ident.name]; ok { symbol_value.types[i] = expr } } } } } symbol_value.poly_names = poly_names[:] }