aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDanielGavin <danielgavin5@hotmail.com>2023-12-29 17:43:10 +0100
committerDanielGavin <danielgavin5@hotmail.com>2023-12-29 17:43:10 +0100
commite93830d576ea387cc7dc635306ce07f43c515853 (patch)
treeb810eadd2dfdbaa3a72f90e9db81ade7ca36ac06 /src
parent589f609f76040c3be20adfa799d8b79345477882 (diff)
More poly work
Diffstat (limited to 'src')
-rw-r--r--src/server/generics.odin93
-rw-r--r--src/server/symbol.odin7
2 files changed, 93 insertions, 7 deletions
diff --git a/src/server/generics.odin b/src/server/generics.odin
index 5d07181..7141a75 100644
--- a/src/server/generics.odin
+++ b/src/server/generics.odin
@@ -57,6 +57,61 @@ resolve_poly :: proc(
}
#partial switch p in specialization.derived {
+ case ^ast.Matrix_Type:
+ if call_matrix, ok := call_node.derived.(^ast.Matrix_Type); ok {
+ found := false
+ if poly_type, ok := p.row_count.derived.(^ast.Poly_Type); ok {
+ if ident, ok := unwrap_ident(poly_type.type); ok {
+ poly_map[ident.name] = call_matrix.row_count
+ }
+
+ if poly_type.specialization != nil {
+ return resolve_poly(
+ ast_context,
+ call_matrix.row_count,
+ call_symbol,
+ p.row_count,
+ poly_map,
+ )
+ }
+ found |= true
+ }
+
+ if poly_type, ok := p.column_count.derived.(^ast.Poly_Type); ok {
+ if ident, ok := unwrap_ident(poly_type.type); ok {
+ poly_map[ident.name] = call_matrix.column_count
+ }
+
+ if poly_type.specialization != nil {
+ return resolve_poly(
+ ast_context,
+ call_matrix.column_count,
+ call_symbol,
+ p.column_count,
+ poly_map,
+ )
+ }
+ found |= true
+ }
+
+ if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok {
+ if ident, ok := unwrap_ident(poly_type.type); ok {
+ poly_map[ident.name] = call_matrix.elem
+ }
+
+ if poly_type.specialization != nil {
+ return resolve_poly(
+ ast_context,
+ call_matrix.elem,
+ call_symbol,
+ p.elem,
+ poly_map,
+ )
+ }
+ found |= true
+ }
+ return found
+ }
case ^ast.Call_Expr:
if call_struct, ok := call_node.derived.(^ast.Struct_Type); ok {
arg_index := 0
@@ -78,7 +133,6 @@ resolve_poly :: proc(
}
}
case ^ast.Struct_Type:
-
case ^ast.Dynamic_Array_Type:
if call_array, ok := call_node.derived.(^ast.Dynamic_Array_Type); ok {
if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok {
@@ -100,6 +154,7 @@ resolve_poly :: proc(
}
case ^ast.Array_Type:
if call_array, ok := call_node.derived.(^ast.Array_Type); ok {
+ found := false
if poly_type, ok := p.elem.derived.(^ast.Poly_Type); ok {
if ident, ok := unwrap_ident(poly_type.type); ok {
poly_map[ident.name] = call_array.elem
@@ -114,11 +169,31 @@ resolve_poly :: proc(
poly_map,
)
}
- return true
+ found |= true
}
+
+ if poly_type, ok := p.len.derived.(^ast.Poly_Type); ok {
+ if ident, ok := unwrap_ident(poly_type.type); ok {
+ poly_map[ident.name] = call_array.len
+ }
+
+ if poly_type.specialization != nil {
+ return resolve_poly(
+ ast_context,
+ call_array.len,
+ call_symbol,
+ p.len,
+ poly_map,
+ )
+ }
+ found |= true
+ }
+
+ return found
}
case ^ast.Map_Type:
if call_map, ok := call_node.derived.(^ast.Map_Type); ok {
+ found := false
if poly_type, ok := p.key.derived.(^ast.Poly_Type); ok {
if ident, ok := unwrap_ident(poly_type.type); ok {
poly_map[ident.name] = call_map.key
@@ -133,8 +208,10 @@ resolve_poly :: proc(
poly_map,
)
}
- return true
- } else if poly_type, ok := p.value.derived.(^ast.Poly_Type); ok {
+ found |= true
+ }
+
+ if poly_type, ok := p.value.derived.(^ast.Poly_Type); ok {
if ident, ok := unwrap_ident(poly_type.type); ok {
poly_map[ident.name] = call_map.value
}
@@ -148,8 +225,9 @@ resolve_poly :: proc(
poly_map,
)
}
- return true
+ found |= true
}
+ return found
}
case ^ast.Multi_Pointer_Type:
if call_pointer, ok := call_node.derived.(^ast.Multi_Pointer_Type);
@@ -224,7 +302,8 @@ find_and_replace_poly_type :: proc(
case ^ast.Array_Type:
if expr, ok := is_in_poly_map(v.elem, poly_map); ok {
v.elem = expr
- } else if expr, ok := is_in_poly_map(v.len, poly_map); ok {
+ }
+ if expr, ok := is_in_poly_map(v.len, poly_map); ok {
v.len = expr
}
case ^ast.Multi_Pointer_Type:
@@ -358,7 +437,7 @@ resolve_generic_function_symbol :: proc(
}
for k, v in poly_map {
- //fmt.println(k, v.derived)
+ //fmt.println(k, v.derived, "\n")
}
if count_required_params > len(call_expr.args) ||
diff --git a/src/server/symbol.odin b/src/server/symbol.odin
index a62c8cf..0f86d5c 100644
--- a/src/server/symbol.odin
+++ b/src/server/symbol.odin
@@ -314,6 +314,7 @@ symbol_to_expr :: proc(
case SymbolFixedArrayValue:
type := new_type(ast.Array_Type, pos, end, allocator)
type.elem = v.expr
+ type.len = v.len
return type
case SymbolMapValue:
type := new_type(ast.Map_Type, pos, end, allocator)
@@ -332,6 +333,12 @@ symbol_to_expr :: proc(
case SymbolUntypedValue:
type := new_type(ast.Basic_Lit, pos, end, allocator)
return type
+ case SymbolMatrixValue:
+ type := new_type(ast.Matrix_Type, pos, end, allocator)
+ type.row_count = v.x
+ type.column_count = v.y
+ type.elem = v.expr
+ return type
case:
log.errorf("Unhandled symbol %v", symbol)
}