1: # 196 "./lpsrc/flx_unify.ipk"
2: open Flx_types
3: open Flx_mtypes1
4: open Flx_mtypes2
5: open Flx_print
6: open Flx_maps
7: open Flx_util
8: open List
9: open Flx_exceptions
10:
(* The unit type, represented as the tuple with no components. *)
let unit_t = `BTYP_tuple []
12:
(* [dual t] first applies [map_btype dual] to [t] and then swaps the
   resulting top-level constructor for its counterpart:
   sum <-> tuple (singletons collapse to their only member, the empty
   tuple becomes void), void -> unit, unitsum k -> tuple of k units,
   typeset <-> intersect, record -> variant, and the two components of
   function, cfunction and array types are exchanged.
   Any other constructor is returned as produced by [map_btype].

   NOTE(review): pointer, lvalue, lift, typeset and intersect call [dual]
   on their payload again after [map_btype dual] has run -- presumably
   [map_btype] has already mapped those payloads; confirm this double
   application is intended. *)
let rec dual t =
  match map_btype dual t with
  | `BTYP_sum [only] -> only
  | `BTYP_sum components -> `BTYP_tuple components

  | `BTYP_tuple [] -> `BTYP_void
  | `BTYP_tuple [only] -> only
  | `BTYP_tuple components -> `BTYP_sum components

  | `BTYP_function (dom, cod) -> `BTYP_function (cod, dom)
  | `BTYP_cfunction (dom, cod) -> `BTYP_cfunction (cod, dom)
  | `BTYP_array (elt, idx) -> `BTYP_array (idx, elt)

  | `BTYP_pointer inner -> `BTYP_pointer (dual inner)
  | `BTYP_lvalue inner -> `BTYP_lvalue (dual inner)
  | `BTYP_lift inner -> `BTYP_lift (dual inner)
  | `BTYP_void -> unit_t
  | `BTYP_unitsum k ->
    (* builds a list of k units; only terminates for k >= 0 *)
    let rec units acc n =
      if n = 0 then acc else units (unit_t :: acc) (n - 1)
    in
    `BTYP_tuple (units [] k)

  | `BTYP_typeset members -> `BTYP_intersect (map dual members)
  | `BTYP_intersect members -> `BTYP_typeset (map dual members)
  | `BTYP_record fields -> `BTYP_variant fields
  | other -> other
44:
45: (* top down check for fix point not under sum or pointer *)
(* [check_recursion t] walks [t] with [iter_btype] and raises
   [Bad_recursion] on the first [`BTYP_fix] it reaches.  The walk does not
   descend into pointer, sum, function or cfunction terms, so fixpoints
   below those constructors are accepted. *)
let rec check_recursion t =
  match t with
  | `BTYP_fix _ -> raise Bad_recursion

  | `BTYP_pointer _
  | `BTYP_sum _
  | `BTYP_function _
  | `BTYP_cfunction _ -> ()

  | other -> iter_btype check_recursion other
57:
(* [term_subst t1 i t2] replaces type variable [i] by [t2] throughout [t1].

   Binders are respected:
   - in a type function the parameter kinds are always substituted, but
     the result and body are left alone when [i] is one of the parameters;
   - in a type match the matched argument is always substituted, and a
     case is left untouched when [i] is one of its pattern variables.
   All other terms are handled structurally through [map_btype]. *)
let term_subst t1 i t2 =
  let rec subst term =
    match term with
    | `BTYP_var (j, _) when i = j -> t2

    | `BTYP_typefun (params, ret, body) ->
      let params = map (fun (name, kind) -> name, subst kind) params in
      if mem_assoc i params then
        (* [i] is rebound by this lambda: stop here *)
        `BTYP_typefun (params, ret, body)
      else
        let body = subst body in
        let ret = subst ret in
        `BTYP_typefun (params, ret, body)

    | `BTYP_type_match (arg, cases) ->
      let arg = subst arg in
      let subst_case
        (({pattern=p; pattern_vars=vs; assignments=asgs}, rhs) as case)
      =
        if IntSet.mem i vs then case
        else
          let asgs = map (fun (v, ty) -> v, subst ty) asgs in
          {pattern = subst p; pattern_vars = vs; assignments = asgs},
          subst rhs
      in
      `BTYP_type_match (arg, map subst_case cases)

    | other -> map_btype subst other
  in
  subst t1
84:
(* [list_subst x t1] applies the substitutions in [x], a list of
   (variable index, replacement) pairs, to [t1] one after another from
   left to right using [term_subst]. *)
let rec list_subst x t1 =
  match x with
  | [] -> t1
  | (i, t2) :: rest -> list_subst rest (term_subst t1 i t2)
87:
(* [varmap0_subst varmap t] maps [t] with [map_btype] and replaces every
   resulting type variable whose index is bound in the hashtable [varmap]
   by its binding; unbound variables are kept as they are. *)
let varmap0_subst varmap t =
  let rec subst term =
    match map_btype subst term with
    | `BTYP_var (i, _) as v ->
      (try Hashtbl.find varmap i with Not_found -> v)
    | other -> other
  in
  subst t
96:
(* [varmap_subst varmap t] is like [varmap0_subst], replacing variables
   bound in the hashtable [varmap], but additionally runs the substitution
   over the parameter kinds, result and body of every type function found
   in the mapped term.

   NOTE(review): this happens after [map_btype] has been applied to the
   same term -- presumably [map_btype] does not reach all of those
   components; confirm. *)
let varmap_subst varmap t =
  let rec subst term =
    match map_btype subst term with
    | `BTYP_var (i, _) as v ->
      (try Hashtbl.find varmap i with Not_found -> v)

    | `BTYP_typefun (params, ret, body) ->
      let params = map (fun (name, kind) -> (name, subst kind)) params in
      let ret = subst ret in
      let body = subst body in
      `BTYP_typefun (params, ret, body)

    | other -> other
  in
  subst t
112:
113: (* the type arguments are matched up with the type
114: variables in order so that
115: vs_i -> ts_i
116: where vs_i might be (fred,var j)
117: *)
(* [mk_varmap vs ts] builds a hashtable binding the index of the n-th
   type variable in [vs] to the n-th type in [ts].
   Fails with [Failure] when the two lists differ in length. *)
let mk_varmap
  (vs:(string * int) list)
  (ts:btypecode_t list)
=
  let nvs = length vs in
  let nts = length ts in
  if nts <> nvs then
    failwith
    (
      "[mk_varmap] wrong number of type args, expected vs=" ^
      si nvs ^
      ", got ts=" ^
      si nts ^
      "\nvs= " ^ catmap "," (fun (s,i) -> s ^ "<"^si i^">") vs
    );
  let varmap = Hashtbl.create 97 in
  (* lengths are known to agree here, so [combine] cannot fail *)
  iter
    (fun ((_, varidx), typ) -> Hashtbl.add varmap varidx typ)
    (combine vs ts);
  varmap
139:
(* [tsubst vs ts t] substitutes the types [ts] for the type variables
   [vs] (matched positionally) in [t].
   Fails with [Failure] if [vs] and [ts] differ in length. *)
let tsubst
  (vs:(string * int) list)
  (ts:btypecode_t list)
  (t:btypecode_t)
=
  let bindings = mk_varmap vs ts in
  varmap_subst bindings t
146:
147:
148: (* returns the most general unifier (mgu)
149: of a set of type equations as a list
150: of variable assignments i -> t
151: or raises Not_found if there is no solution
152:
153: HOW IT WORKS:
154:
155: We start with some set of type equations
156: t1 = t2
157: t3 = t4 (1)
158: ...
159:
160: in which the LHS and RHS are general terms that
161: may contain type variables.
162:
163: We want to say whether the equations are consistent,
164: and if so, to return a solution of the form
165: of a set of equations:
166:
167: v1 = u1
168: v2 = u2 (2)
169:
170: where v1 .. vn are type variables
171: which do not occur in any of the
172: terms u1 .. un
173:
174: Such a set is a solution if by replacing v1 with u1,
175: v2 with u2 .. vn with un,
176: everywhere they occur in t1, t2 .... tn,
177: the original equations are reduced to
178: equations terms which are structurally equal
179:
180: The technique is to pick one equation,
181: and match up the outermost structure,
182: making new equations out of the pieces in the middle,
183: or failing if the outer structure does not match.
184:
185: We discard the original equation,
186: add the new equations to the set,
187: and then for any variable assignments of form (2)
188: found, we eliminate that variable in
189: all the other equations by substitution.
190:
191:
192: At the end we are guaranteed to either have found
193: the equations have no solution, or computed one,
194: although it may be that the terms u1 .. u2 ..
195: contain some type variables.
196:
197: There is a caveat though: we may obtain
198: an equation
199:
200: v = t
201:
202: where v occurs in t, that is, a recursive equation.
203: If that happens, we eliminate the occurrences
204: of v in t before replacement in other equations:
205: we do this by replacing the RHS occurrences of
206: v with a fixpoint operator.
207:
208: *)
209:
210:
(* [var_i_occurs i t] is true when a type variable with index [i] is
   reached while walking [t] with [iter_btype].  [Not_found] is used
   internally to abandon the walk at the first hit. *)
let var_i_occurs i t =
  let rec scan term : unit =
    match term with
    | `BTYP_var (j, _) when i = j -> raise Not_found
    | other -> iter_btype scan other
  in
  try
    scan t;
    false
  with Not_found -> true
220:
(* [vars_in t] is the set of indices of all type variables reached while
   walking [t] with [iter_btype]. *)
let vars_in t =
  let found = ref IntSet.empty in
  let rec collect term =
    match term with
    | `BTYP_var (i, _) -> found := IntSet.add i !found
    | other -> iter_btype collect other
  in
  collect t;
  !found
228:
(* [fix i t] replaces every occurrence of type variable [i] in [t] by a
   fixpoint marker [`BTYP_fix n].  [n] starts at 0 for [t] itself and is
   decremented once for every constructor entered on the way down, so the
   marker records how far up the enclosing term [t] lies.

   Unitsum, void, existing fixpoints, applications, type functions and the
   other meta-level terms listed below are returned unchanged without
   being searched. *)
let fix i t =
  let rec go level term =
    let down sub = go (level - 1) sub in
    let down_field (label, ty) = label, down ty in
    match term with
    | `BTYP_var (k, _) -> if k = i then `BTYP_fix level else term
    | `BTYP_inst (k, ts) -> `BTYP_inst (k, map down ts)
    | `BTYP_tuple ts -> `BTYP_tuple (map down ts)
    | `BTYP_sum ts -> `BTYP_sum (map down ts)
    | `BTYP_intersect ts -> `BTYP_intersect (map down ts)
    | `BTYP_typeset ts -> `BTYP_typeset (map down ts)
    | `BTYP_function (a, b) -> `BTYP_function (down a, down b)
    | `BTYP_cfunction (a, b) -> `BTYP_cfunction (down a, down b)
    | `BTYP_pointer a -> `BTYP_pointer (down a)
    | `BTYP_lvalue a -> `BTYP_lvalue (down a)
    | `BTYP_lift a -> `BTYP_lift (down a)
    | `BTYP_array (a, b) -> `BTYP_array (down a, down b)

    | `BTYP_record fields -> `BTYP_record (map down_field fields)
    | `BTYP_variant fields -> `BTYP_variant (map down_field fields)

    | `BTYP_unitsum _
    | `BTYP_void
    | `BTYP_fix _
    | `BTYP_apply _
    | `BTYP_typefun _
    | `BTYP_type _
    | `BTYP_type_tuple _
    | `BTYP_type_match _
    | `BTYP_typesetunion _
    | `BTYP_typesetintersection _ -> term

    (* Jay case: the original author was unsure about this one *)
    | `BTYP_case (a, b, c) -> `BTYP_case (down a, b, down c)
  in
  go 0 t
267:
(* [var_list_occurs ls t] is true when at least one of the variable
   indices in [ls] occurs in [t] according to [var_i_occurs]. *)
let var_list_occurs ls t =
  List.exists (fun i -> var_i_occurs i t) ls
272:
(* [lstrip dfns t] removes every [`BTYP_lvalue] and [`BTYP_lift] wrapper
   found in sums, tuples, arrays, records, variants, functions, pointers
   and instance arguments of [t], and renumbers fixpoints to compensate.

   While descending, a trail is kept with one entry per constructor
   entered: 1 for a wrapper that was dropped, 0 for a constructor that was
   kept.  For [`BTYP_fix i] with negative [i] the first [-i] trail entries
   are summed and added to [i].  [hd] fails if the trail is shorter than
   that.  [dfns] is not used. *)
let lstrip dfns t =
  let rec strip trail term =
    let keep sub = strip (0 :: trail) sub in
    let keep_field (label, ty) = label, keep ty in
    match term with
    | `BTYP_sum ls -> `BTYP_sum (map keep ls)
    | `BTYP_tuple ls -> `BTYP_tuple (map keep ls)
    | `BTYP_array (a, b) -> `BTYP_array (keep a, keep b)
    | `BTYP_record fields -> `BTYP_record (map keep_field fields)
    | `BTYP_variant fields -> `BTYP_variant (map keep_field fields)

    (* the original author suspected the function cases are wrong *)
    | `BTYP_function (a, b) -> `BTYP_function (keep a, keep b)
    | `BTYP_cfunction (a, b) -> `BTYP_cfunction (keep a, keep b)

    | `BTYP_pointer a -> `BTYP_pointer (keep a)

    | `BTYP_lvalue a
    | `BTYP_lift a -> strip (1 :: trail) a

    | `BTYP_fix i ->
      (* count the wrappers dropped between here and the binder *)
      let rec dropped k acc path =
        if k < 0 then
          let acc = acc + hd path in
          let path = tl path in
          dropped (k + 1) acc path
        else acc
      in
      `BTYP_fix (i + dropped i 0 trail)

    | `BTYP_inst (i, ts) -> `BTYP_inst (i, map keep ts)
    | _ -> term
  in
  strip [] t
304:
305:
306:
307: (* NOTE: this algorithm unifies EQUATIONS
308: not inequations, therefore it doesn't
309: handle any subtyping
310: *)
311:
312: (* NOTE: at the moment,
313: unification doesn't care about type variable
314: meta types: we should probably require them
315: to be the same for an assignment or fail
316: the unification .. however that requires
317: comparing the metatypes for equality, and to do
318: that right requires unification .. :)
319: *)
320:
(* [unification allow_lval dfns eqns dvars] solves the type equations
   [eqns] and returns the most general unifier as a list of assignments
   (variable index, type).

   - [dvars] is the set of variables that may be assigned; an equation
     that would bind any other variable makes unification fail.
   - [allow_lval] lets a right-hand [`BTYP_lvalue t] be equated with a
     left-hand term by equating that term with [t].
   - [dfns] is only handed on to [lstrip].

   Each step removes one equation, decomposes it into equations on the
   components and/or yields one assignment; an assignment is substituted
   into all remaining equations and into the assignments found so far.
   A recursive assignment (the variable occurs in its own value) is
   closed with [fix].

   Raises [Not_found] when the equations have no solution. *)
let unification allow_lval dfns
  (eqns: (btypecode_t * btypecode_t) list)
  (dvars: IntSet.t)
: (int * btypecode_t) list =
  let pending = ref eqns in
  let mgu = ref [] in
  let push eqn = pending := eqn :: !pending in

  (* equate two lists of the same length component by component *)
  let push_pairs xs ys =
    pending := fold_left2 (fun acc x y -> (x, y) :: acc) !pending xs ys
  in

  (* equate the fields of two records or variants: the label sets must
     agree, field order does not matter *)
  let push_fields fs1 fs2 =
    if length fs1 <> length fs2 then raise Not_found;
    let by_label (l1, _) (l2, _) = compare l1 l2 in
    let fs1 = sort by_label fs1 in
    let fs2 = sort by_label fs2 in
    if (map fst fs1) <> (map fst fs2) then raise Not_found;
    push_pairs (map snd fs1) (map snd fs2)
  in

  let rec loop () : unit =
    match !pending with
    | [] -> ()
    | h :: rest ->
      pending := rest;
      (* the assignment produced by this equation, if any *)
      let binding =
        match h with
        | (`BTYP_var (i,mi) as ti), (`BTYP_var (j,mj) as tj) ->
          (* the meta types have to agree *)
          if mi <> mj then raise Not_found;
          if i = j then None
          else if IntSet.mem i dvars then Some (i, tj)
          else if IntSet.mem j dvars then Some (j, ti)
          else raise Not_found

        | `BTYP_lvalue t1, `BTYP_lvalue t2 ->
          push (t1, t2); None

        (* An argument of type lvalue t can match a parameter of type t,
           but not the other way around.  This case must come before the
           variable assignment case below so that a variable is never
           bound to an lvalue here. *)
        | t1, `BTYP_lvalue t2 when allow_lval ->
          push (t1, t2); None

        | `BTYP_var (i,_), t
        | t, `BTYP_var (i,_) ->
          (* the meta type of t is not checked *)
          if not (IntSet.mem i dvars) then raise Not_found;
          if var_i_occurs i t
          then Some (i, fix i t)
          else Some (i, lstrip dfns t)

        | `BTYP_lift t1, `BTYP_lift t2 ->
          push (t1, t2); None

        | `BTYP_intersect ts, t
        | t, `BTYP_intersect ts ->
          iter (fun t' -> push (t, t')) ts; None

        | `BTYP_pointer t1, `BTYP_pointer t2 ->
          push (t1, t2); None

        | `BTYP_unitsum i, `BTYP_unitsum j when i = j -> None

        | `BTYP_unitsum k, `BTYP_sum ls
        | `BTYP_sum ls, `BTYP_unitsum k when length ls = k ->
          (* only a sum made entirely of variables can equal a unitsum *)
          iter
            (function
              | `BTYP_var _ as v -> push (v, unit_t)
              | _ -> raise Not_found
            )
            ls;
          None

        | `BTYP_array (t11, t12), `BTYP_array (t21, t22)
        | `BTYP_function (t11, t12), `BTYP_function (t21, t22)
        | `BTYP_cfunction (t11, t12), `BTYP_cfunction (t21, t22) ->
          pending := (t11,t21) :: (t12,t22) :: !pending; None

        | `BTYP_record [], `BTYP_tuple []
        | `BTYP_tuple [], `BTYP_record [] -> None

        | `BTYP_record fs1, `BTYP_record fs2 ->
          push_fields fs1 fs2; None

        | `BTYP_variant [], `BTYP_void
        | `BTYP_void, `BTYP_variant [] -> None

        | `BTYP_variant fs1, `BTYP_variant fs2 ->
          push_fields fs1 fs2; None

        | `BTYP_type i, `BTYP_type j when i = j -> None
        | `BTYP_void, `BTYP_void -> None

        | `BTYP_inst (i1,ts1), `BTYP_inst (i2,ts2) ->
          if i1 <> i2 then raise Not_found;
          if length ts1 <> length ts2 then raise Not_found;
          push_pairs ts1 ts2; None

        | `BTYP_fix i, `BTYP_fix j ->
          if i <> j then raise Not_found;
          None

        (* Array/tuple sidedness is preserved because lvalue decay only
           applies to the right-hand term. *)
        | `BTYP_tuple ls, `BTYP_array (ta,`BTYP_unitsum n)
          when n = length ls ->
          iter (fun t -> push (t, ta)) ls; None

        | `BTYP_array (ta,`BTYP_unitsum n), `BTYP_tuple ls
          when n = length ls ->
          iter (fun t -> push (ta, t)) ls; None

        (* A type tuple is treated exactly like a tuple type; the original
           author was not sure this is right. *)
        | (`BTYP_type_tuple ls1, `BTYP_type_tuple ls2)
        | (`BTYP_tuple ls1, `BTYP_tuple ls2)
        | (`BTYP_sum ls1, `BTYP_sum ls2)
          when length ls1 = length ls2 ->
          push_pairs ls1 ls2; None

        (* structural, not functional, equality of lambdas: rename the
           parameters of the first to those of the second and compare
           the bodies *)
        | `BTYP_typefun (p1,_,b1), `BTYP_typefun (p2,_,b2)
          when length p1 = length p2 ->
          let renaming =
            map2 (fun (i1,_) (i2,t) -> i1,`BTYP_var (i2,t)) p1 p2
          in
          push (list_subst renaming b1, b2); None

        | _ -> raise Not_found
      in
      begin match binding with
      | None -> ()
      | Some (i,t) ->
        (* eliminate the variable everywhere else *)
        pending :=
          map
            (fun (a,b) -> term_subst a i t, term_subst b i t)
            !pending;
        assert (not (mem_assoc i !mgu));
        mgu :=
          (i,t) ::
          map (fun (j,t') -> j, term_subst t' i t) !mgu
      end;
      loop ()
  in
  loop ();
  !mgu
563:
(* [find_vars_eqns eqns] returns a pair of sets: the type variables found
   (by [vars_in]) in the left-hand sides of [eqns], and those found in
   the right-hand sides. *)
let find_vars_eqns eqns =
  fold_left
    (fun (lhs_vars, rhs_vars) (l, r) ->
      let lhs_vars = IntSet.union lhs_vars (vars_in l) in
      let rhs_vars = IntSet.union rhs_vars (vars_in r) in
      lhs_vars, rhs_vars
    )
    (IntSet.empty, IntSet.empty)
    eqns
574:
(* [maybe_unification dfns eqns] unifies [eqns] without lvalue decay,
   allowing every variable on either side to be assigned.
   Returns [Some mgu], or [None] if unification raises [Not_found]. *)
let maybe_unification dfns eqns =
  let lhs_vars, rhs_vars = find_vars_eqns eqns in
  try Some (unification false dfns eqns (IntSet.union lhs_vars rhs_vars))
  with Not_found -> None
580:
(* [maybe_matches dfns eqns] unifies [eqns] with lvalue decay enabled,
   allowing every variable on either side to be assigned.
   Returns [Some mgu], or [None] if unification raises [Not_found]. *)
let maybe_matches dfns eqns =
  let lhs_vars, rhs_vars = find_vars_eqns eqns in
  try Some (unification true dfns eqns (IntSet.union lhs_vars rhs_vars))
  with Not_found -> None
586:
(* [maybe_specialisation dfns eqns] unifies [eqns] with lvalue decay
   enabled, but only the variables occurring in left-hand sides may be
   assigned.
   Returns [Some mgu], or [None] if unification raises [Not_found]. *)
let maybe_specialisation dfns eqns =
  let assignable = fst (find_vars_eqns eqns) in
  try Some (unification true dfns eqns assignable)
  with Not_found -> None
591:
(* [unifies dfns t1 t2] is true when [maybe_unification] finds a unifier
   for the single equation t1 = t2. *)
let unifies dfns t1 t2 =
  match maybe_unification dfns [t1,t2] with
  | Some _ -> true
  | None -> false
597:
(* [ge dfns a b] is true when [maybe_specialisation] succeeds on the
   equation a = b, i.e. when assigning only variables of [a] makes it
   equal to [b].  The unifier itself is discarded. *)
let ge dfns a b =
  match maybe_specialisation dfns [a,b] with
  | Some _ -> true
  | None -> false
611:
(* [compare_sigs dfns a b] orders two signatures using [ge] in both
   directions: [`Equal] when each is >= the other, [`Greater] when only
   a >= b, [`Less] when only b >= a, and [`Incomparable] otherwise. *)
let compare_sigs dfns a b =
  let a_ge_b = ge dfns a b in
  let b_ge_a = ge dfns b a in
  if a_ge_b then
    (if b_ge_a then `Equal else `Greater)
  else
    (if b_ge_a then `Less else `Incomparable)
618:
619:
620: (* returns true if a and b have an mgu,
621: and also adds each element of the mgu to
622: the varmap if it isn't already present
623: this routine is ONLY to be used for
624: calculating the return types of functions,
625: where we're unifying the type of the
626: return statements... probably fails
627: for generic functions .. since the two
628: kinds of type variables aren't distinguished
629: (Fun ret type var is an unknown type, not a
630: variable one .. it must be eliminated, but
631: type parameters must not be [since they're
632: instantiated to multiple values .. ..])
633:
634: The subtyping rule for lvalues also applies
635: here. An lvalue type for a returned expression
636: is compatible with a non-value function return.
637:
638: The unification algorithm can account for this,
639: it requires the LHS = RHS equation to support
640: an extra 'lvalue' in the RHS, but not the other
641: way around. So the expression type has to be the RHS
642: and the declared type the LHS.
643: *)
644:
(* [do_unify syms a b] unifies [a] with [b] (lvalue decay enabled, so the
   expression type should be [b]) after applying the bindings already in
   [syms.varmap], and records new bindings in that table.

   For each assignment (i, t) of the unifier:
   - if [i] is already in [syms.varmap] the stored type must equal [t],
     otherwise [Failure] is raised;
   - otherwise the definition of [i] in [syms.dfns] decides: for a glr or
     function symbol the binding is added, a declared type variable is
     left alone, and anything else raises [Failure].

   Returns [true] on success and [false] when [Not_found] is raised, which
   covers both failed unification and an index missing from [syms.dfns]. *)
let do_unify syms a b =
  let lhs = varmap_subst syms.varmap a in
  let rhs = varmap_subst syms.varmap b in
  let eqns = [lhs, rhs] in
  let lvars, rvars = find_vars_eqns eqns in
  let dvars = IntSet.union lvars rvars in

  (* Store one assignment.  Unknown function return types are bound to
     special type variables that share the index of their function; this
     is where the computed type gets stored for such variables. *)
  let record (i, t) =
    if Hashtbl.mem syms.varmap i then begin
      if Hashtbl.find syms.varmap i <> t then
        failwith
        (
          "[do_unify] binding for type variable " ^ string_of_int i ^
          " is inconsistent\n"
        )
    end else
      match Hashtbl.find syms.dfns i with
      | { symdef=`SYMDEF_glr _ }
      | { symdef=`SYMDEF_function _ } ->
        Hashtbl.add syms.varmap i t

      (* if it is a declared type variable, leave it alone *)
      | { symdef=`SYMDEF_typevar _ } -> ()

      | _ ->
        failwith
        (
          "[do_unify] attempt to add non-function return unknown type variable "^
          si i^", type "^sbt syms.dfns t^" to hashtble"
        )
  in
  try
    iter record (unification true syms.dfns eqns dvars);
    true
  with Not_found -> false
720:
(* [memq trail (a,b)] is true when [trail] contains a pair whose two
   components are physically equal ([==]) to [a] and [b] respectively. *)
let memq trail (a,b) =
  List.exists (fun (i, j) -> i == a && j == b) trail
724:
(* [type_eq' dfns allow_lval ltrail ldepth rtrail rdepth trail t1 t2]
   decides structural equality of [t1] and [t2], unrolling fixpoints.

   - [ltrail] / [rtrail] associate a depth with each enclosing term on the
     left / right side; [ldepth] / [rdepth] are the current depths.  A
     [`BTYP_fix i] is resolved by looking up depth + i in the trail of
     its side.
   - [trail] holds the pairs already under comparison; meeting the same
     pair again (physical equality, see [memq]) counts as equal, which
     terminates the comparison of recursive types.
   - [allow_lval] lets a right-hand [`BTYP_lvalue p] match a left-hand
     term that equals [p], but not the other way around.
   - [dfns] is only passed along.

   Records and variants are compared up to field order.  The empty record
   equals the empty tuple (unit).  The empty variant equals
   [`BTYP_void], matching [unification] and [normalise_type] in this file;
   an earlier version compared it with the empty tuple, which looks like a
   copy of the record case.

   Raises [Not_found] if a fixpoint index is not present in its trail. *)
let rec type_eq' dfns allow_lval ltrail ldepth rtrail rdepth trail t1 t2 =
  if memq trail (t1,t2) then true
  else
    (* compare two subterms one level further down *)
    let te a b =
      type_eq' dfns allow_lval
        ((ldepth,t1)::ltrail) (ldepth+1)
        ((rdepth,t2)::rtrail) (rdepth+1)
        ((t1,t2)::trail)
        a b
    in
    (* pairwise comparison; callers check the lengths agree first *)
    let all_te xs ys =
      fold_left2 (fun ok a b -> ok && te a b) true xs ys
    in
    (* labelled fields: same labels, equal types, any order *)
    let fields_eq fs1 fs2 =
      length fs1 = length fs2 &&
      begin
        let by_label (l1,_) (l2,_) = compare l1 l2 in
        let fs1 = sort by_label fs1 in
        let fs2 = sort by_label fs2 in
        map fst fs1 = map fst fs2 &&
        all_te (map snd fs1) (map snd fs2)
      end
    in
    match t1,t2 with
    | `BTYP_inst (i1,ts1),`BTYP_inst (i2,ts2) ->
      i1 = i2 &&
      length ts1 = length ts2 &&
      all_te ts1 ts2

    | `BTYP_unitsum i,`BTYP_unitsum j -> i = j

    | `BTYP_sum ts1, `BTYP_sum ts2
    | `BTYP_tuple ts1,`BTYP_tuple ts2 ->
      length ts1 = length ts2 && all_te ts1 ts2

    | `BTYP_record [],`BTYP_tuple []
    | `BTYP_tuple [],`BTYP_record [] -> true

    | `BTYP_record fs1,`BTYP_record fs2 -> fields_eq fs1 fs2

    (* a variant with no cases is the void type, not unit *)
    | `BTYP_variant [],`BTYP_void
    | `BTYP_void,`BTYP_variant [] -> true

    | `BTYP_variant fs1,`BTYP_variant fs2 -> fields_eq fs1 fs2

    | `BTYP_array (s1,d1),`BTYP_array (s2,d2)
    | `BTYP_function (s1,d1),`BTYP_function (s2,d2)
    | `BTYP_cfunction (s1,d1),`BTYP_cfunction (s2,d2)
    | `BTYP_apply(s1,d1),`BTYP_apply(s2,d2)
      -> te s1 s2 && te d1 d2

    (* argument order is kept because lvalue decay is one-sided *)
    | `BTYP_array (ta,`BTYP_unitsum n),`BTYP_tuple ts
      when length ts = n ->
      fold_left (fun ok t -> ok && te ta t) true ts

    | `BTYP_tuple ts,`BTYP_array (ta,`BTYP_unitsum n)
      when length ts = n ->
      fold_left (fun ok t -> ok && te t ta) true ts

    | `BTYP_pointer p1,`BTYP_pointer p2
      -> te p1 p2

    | `BTYP_lift p1,`BTYP_lift p2
    | `BTYP_lvalue p1,`BTYP_lvalue p2
      -> te p1 p2

    (* only the right-hand side steps under the lvalue wrapper *)
    | p1,(`BTYP_lvalue p2 as lt) when allow_lval
      ->
      type_eq' dfns allow_lval
        ltrail ldepth
        ((rdepth,lt)::rtrail) (rdepth+1)
        ((p1,lt)::trail)
        p1 p2

    | `BTYP_void,`BTYP_void
      -> true

    | `BTYP_var i, `BTYP_var j ->
      i = j

    | `BTYP_fix i,`BTYP_fix j ->
      let a = assoc (ldepth+i) ltrail in
      let b = assoc (rdepth+j) rtrail in
      type_eq' dfns allow_lval ltrail ldepth rtrail rdepth trail a b

    | `BTYP_fix i,t ->
      let a = assoc (ldepth+i) ltrail in
      type_eq' dfns allow_lval ltrail ldepth rtrail rdepth trail a t

    | t,`BTYP_fix j ->
      let b = assoc (rdepth+j) rtrail in
      type_eq' dfns allow_lval ltrail ldepth rtrail rdepth trail t b

    (* lambdas: rename the parameters of the first to those of the
       second, then compare the bodies *)
    | `BTYP_typefun (p1,_,b1), `BTYP_typefun (p2,_,b2) ->
      length p1 = length p2 &&
      let vs = map2 (fun (i1,_) (i2,t) -> i1,`BTYP_var (i2,t)) p1 p2 in
      let b1 = list_subst vs b1 in
      te b1 b2

    | _ -> false
844:
(* [type_eq dfns t1 t2] is [type_eq'] started with empty trails at depth
   zero and lvalue decay disabled. *)
let type_eq dfns t1 t2 =
  let allow_lval = false in
  type_eq' dfns allow_lval [] 0 [] 0 [] t1 t2
847:
(* [type_match dfns t1 t2] is [type_eq'] started with empty trails at
   depth zero and lvalue decay enabled, so a right-hand lvalue may match
   a left-hand non-lvalue. *)
let type_match dfns t1 t2 =
  let allow_lval = true in
  type_eq' dfns allow_lval [] 0 [] 0 [] t1 t2
850:
851: (* NOTE: only works on explicit fixpoint operators,
852: i.e. it won't work on typedefs: no name lookup,
853: these should be removed first ..
854: another view: only works on non-generative types.
855: *)
856:
(* [unfold dfns t] replaces each fixpoint in [t] that refers to [t]
   itself (its index negated equals its depth below the root) by the
   whole of [t].  The inserted copy is not unfolded again.

   Fails with [Failure] on a fixpoint that points above the root.
   Patterns of a type match are not visited, only the matched argument
   and the case results.  [dfns] is not used. *)
let unfold dfns t =
  let whole = t in
  let rec expand depth term =
    let sub x = expand (depth + 1) x in
    let sub_field (label, ty) = label, sub ty in
    match term with
    | `BTYP_fix i when (-i) = depth -> whole
    | `BTYP_fix i when (-i) > depth ->
      failwith ("[unfold] Fix point outside term, depth="^string_of_int i)

    | `BTYP_sum ls -> `BTYP_sum (map sub ls)
    | `BTYP_tuple ls -> `BTYP_tuple (map sub ls)
    | `BTYP_record fields -> `BTYP_record (map sub_field fields)
    | `BTYP_variant fields -> `BTYP_variant (map sub_field fields)
    | `BTYP_array (a, b) -> `BTYP_array (sub a, sub b)
    | `BTYP_function (a, b) -> `BTYP_function (sub a, sub b)
    | `BTYP_cfunction (a, b) -> `BTYP_cfunction (sub a, sub b)
    | `BTYP_pointer a -> `BTYP_pointer (sub a)
    | `BTYP_lvalue a -> `BTYP_lvalue (sub a)
    | `BTYP_lift a -> `BTYP_lift (sub a)

    | `BTYP_apply (a, b) -> `BTYP_apply (sub a, sub b)
    | `BTYP_inst (i, ts) -> `BTYP_inst (i, map sub ts)
    | `BTYP_typefun (p, r, b) -> `BTYP_typefun (p, r, sub b)

    | `BTYP_type_match (arg, cases) ->
      let arg = sub arg in
      (* recursions inside patterns are left alone: their meaning is
         not settled yet *)
      let cases = map (fun (pat, rhs) -> pat, sub rhs) cases in
      `BTYP_type_match (arg, cases)

    | _ -> term
  in
  expand 0 t
890:
(* Carries the replacement term out of the search in [fold]. *)
exception Found of btypecode_t

(* [fold dfns t] undoes an [unfold]; it does not minimise an arbitrary
   type.  [t] is searched top-down while a trail of (depth, enclosing
   term) pairs is kept.  At a fixpoint with non-zero index the enclosing
   term it refers to is looked up; if that term is [type_eq] to the whole
   of [t], it is returned in place of [t].  If no such fixpoint is found
   [t] is returned unchanged.

   A fixpoint whose target is not in the trail is ignored.  Type sets,
   type functions, type tuples and type matches are not searched, on the
   assumption that a fixpoint cannot span them. *)
let fold dfns t =
  let rec search trail depth term =
    let visit sub = search ((depth, term) :: trail) (depth + 1) sub in
    let visit_field (_, ty) = visit ty in
    match term with
    | `BTYP_intersect ls
    | `BTYP_sum ls
    | `BTYP_inst (_, ls)
    | `BTYP_tuple ls -> iter visit ls

    | `BTYP_record fields
    | `BTYP_variant fields -> iter visit_field fields

    | `BTYP_array (a, b)
    | `BTYP_function (a, b)
    | `BTYP_cfunction (a, b)
    | `BTYP_apply (a, b) -> visit a; visit b

    | `BTYP_case (a, _, c) -> visit a; visit c

    | `BTYP_pointer a
    | `BTYP_lvalue a
    | `BTYP_lift a -> visit a

    | `BTYP_void
    | `BTYP_unitsum _
    | `BTYP_var _
    | `BTYP_fix 0 -> ()

    | `BTYP_fix i ->
      begin try
        let candidate = assoc (depth + i) trail in
        if type_eq dfns candidate t then raise (Found candidate)
      with Not_found -> ()
      end

    (* assume a fixpoint cannot span these boundaries *)
    | `BTYP_typesetintersection _
    | `BTYP_typesetunion _
    | `BTYP_typeset _
    | `BTYP_typefun _
    | `BTYP_type _
    | `BTYP_type_tuple _
    | `BTYP_type_match _ -> ()
  in
  try search [] 0 t; t
  with Found folded -> folded
941:
942: (* produces a unique minimal representation of a type
943: by folding at every node *)
944:
(* [minimise dfns t] applies [fold dfns] through [map_btype] to [t] and
   then folds the result itself. *)
let minimise dfns t =
  let folded_inside = map_btype (fold dfns) t in
  fold dfns folded_inside
946:
(* [var_occurs t] is true when [t] contains a type variable that is not
   bound by an enclosing type function.  Only the body of a type function
   is searched, with its parameters treated as bound.

   Fails with [Failure] on any constructor not handled below. *)
let var_occurs t =
  let rec scan bound term =
    let go sub = scan bound sub in
    let go_field (_, ty) = go ty in
    match term with
    | `BTYP_intersect ls
    | `BTYP_typeset ls
    | `BTYP_typesetintersection ls
    | `BTYP_typesetunion ls
    | `BTYP_sum ls
    | `BTYP_inst (_, ls)
    | `BTYP_tuple ls -> iter go ls

    | `BTYP_record fields
    | `BTYP_variant fields -> iter go_field fields

    | `BTYP_array (a, b)
    | `BTYP_function (a, b)
    | `BTYP_cfunction (a, b) -> go a; go b

    | `BTYP_pointer a
    | `BTYP_lvalue a
    | `BTYP_lift a -> go a

    | `BTYP_unitsum _
    | `BTYP_void
    | `BTYP_fix _ -> ()

    (* a free variable ends the search *)
    | `BTYP_var (k, _) -> if not (mem k bound) then raise Not_found

    | `BTYP_typefun (params, _, body) ->
      scan (map fst params @ bound) body

    | _ -> failwith "[var_occurs] unexpected metatype"
  in
  try scan [] t; false
  with Not_found -> true
979:
(* [normalise_type t] returns [(vars, t')] where [t'] is [t] with
   - the empty record replaced by the empty tuple,
   - the empty variant replaced by void,
   - type variables renumbered 0, 1, 2, ... in the order in which they
     are first met during the [map_btype] traversal,
   and [vars] lists the original variable indices in order of new
   number. *)
let normalise_type t =
  let next = ref 0 in
  let seen = ref [] in
  (* new number for variable i, allocating one on first sight *)
  let renumber i =
    match list_index !seen i with
    | Some j -> j
    | None ->
      let fresh = !next in
      incr next;
      seen := !seen @ [i];
      fresh
  in
  let rec norm term =
    match map_btype norm term with
    | `BTYP_record [] -> `BTYP_tuple []
    | `BTYP_variant [] -> `BTYP_void
    | `BTYP_var (i, mt) -> `BTYP_var (renumber i, mt)
    | other -> other
  in
  let normalised = norm t in
  !seen, normalised
1001:
(* The identity function. *)
let ident = fun x -> x
1003:
1004: (* not really right! Need to map the types as well,
1005: since we're instantiating a polymorphic term with
1006: a more specialised one
1007:
1008: Also won't substitute into LHS of things like direct_apply.
1009: *)
(* [expr_term_subst e1 i e2] maps [e1] with [map_tbexpr] and replaces
   every resulting name expression with index [i] by [e2].  Types are
   passed through unchanged (the identity is used for them). *)
let expr_term_subst e1 i e2 =
  let rec subst e =
    match map_tbexpr ident subst ident e with
    | `BEXPR_name (j, _), _ when i = j -> e2
    | mapped -> mapped
  in
  subst e1
1015:
(* [expr_unification dfns eqns tdvars edvars] unifies the typed
   expression equations [eqns] and returns a pair
   (type assignments, expression assignments).

   - [edvars] is the set of expression variables that may be assigned.
   - [tdvars] is the set of type variables that may be assigned.

   Expressions are unified first.  The types of both sides of every
   equation visited, and the target types of matching coercions, are
   collected and unified at the end with [unification] (lvalue decay
   enabled).

   Raises [Not_found] when there is no solution.  The prim, direct and
   stack application forms are not expected here and trip an assertion.

   NOTE(review): matching closures compare only their index; their type
   argument lists are not added to the type equations, although the
   original notes say ts lists do have to be unified -- confirm. *)
let expr_unification dfns
  (eqns: (tbexpr_t * tbexpr_t) list)
  (tdvars: IntSet.t)
  (edvars: IntSet.t)
:
  (int * btypecode_t) list *
  (int * tbexpr_t) list
=
  let teqns = ref [] in
  let pending = ref eqns in
  let mgu = ref [] in
  let rec loop () : unit =
    match !pending with
    | [] -> ()
    | h :: rest ->
      pending := rest;
      let (lhse,lhst),(rhse,rhst) = h in
      (* the cached types of the two terms are unified after the terms *)
      teqns := (lhst,rhst) :: !teqns;

      (* the assignment produced by this equation, if any *)
      let binding =
        match (lhse,rhse) with
        | (`BEXPR_name (i,[]) as ei), (`BEXPR_name (j,[]) as ej) ->
          if i = j then None
          else if IntSet.mem i edvars then Some (i,(ej,rhst))
          else if IntSet.mem j edvars then Some (j,(ei,lhst))
          else raise Not_found

        | `BEXPR_name (i,_), x ->
          if not (IntSet.mem i edvars) then raise Not_found;
          Some (i,(x,rhst))

        | x, `BEXPR_name (i,_) ->
          if not (IntSet.mem i edvars) then raise Not_found;
          Some (i,(x,lhst))

        | `BEXPR_apply (f1,e1), `BEXPR_apply (f2,e2) ->
          pending := (f1,f2) :: (e1,e2) :: !pending;
          None

        | `BEXPR_closure (i,_), `BEXPR_closure (j,_) when i = j -> None

        | `BEXPR_apply_prim _, _
        | `BEXPR_apply_direct _, _
        | `BEXPR_apply_stack _, _
        | _, `BEXPR_apply_prim _
        | _, `BEXPR_apply_direct _
        | _, `BEXPR_apply_stack _
          -> assert false

        | `BEXPR_coerce (e,t), `BEXPR_coerce (e',t') ->
          teqns := (t,t') :: !teqns;
          pending := (e,e') :: !pending;
          None

        | (`BEXPR_tuple ls1, `BEXPR_tuple ls2)
          when length ls1 = length ls2 ->
          pending :=
            fold_left2 (fun acc a b -> (a,b) :: acc) !pending ls1 ls2;
          None

        | _ -> raise Not_found
      in
      begin match binding with
      | None -> ()
      | Some (i,t) ->
        (* eliminate the variable everywhere else *)
        pending :=
          map
            (fun (a,b) -> expr_term_subst a i t, expr_term_subst b i t)
            !pending;
        assert (not (mem_assoc i !mgu));
        mgu :=
          (i,t) ::
          map (fun (j,t') -> j, expr_term_subst t' i t) !mgu
      end;
      loop ()
  in
  loop ();
  let tmgu = unification true dfns !teqns tdvars in
  tmgu,
  !mgu
1157:
(* [setoflist ls] is the [IntSet] containing exactly the elements of
   [ls]. *)
let setoflist ls =
  let rec add_all acc = function
    | [] -> acc
    | i :: rest -> add_all (IntSet.add i acc) rest
  in
  add_all IntSet.empty ls
1159:
(* [expr_maybe_matches dfns tvars evars le re] tries to unify the typed
   expressions [le] and [re], where [tvars] and [evars] list the type and
   expression variables that may be assigned.
   Returns [Some (type assignments, expression assignments)], or [None]
   if [expr_unification] raises [Not_found]. *)
let expr_maybe_matches (dfns:symbol_table_t)
  (tvars:int list) (evars:int list)
  (le: tbexpr_t)
  (re:tbexpr_t)
:
  ((int * btypecode_t) list *
  (int * tbexpr_t) list) option
=
  let tdvars = setoflist tvars in
  let edvars = setoflist evars in
  try Some (expr_unification dfns [le,re] tdvars edvars)
  with Not_found -> None
1176:
1177: