...

Source file src/pkg/cmd/compile/internal/ssa/poset.go

     1	// Copyright 2018 The Go Authors. All rights reserved.
     2	// Use of this source code is governed by a BSD-style
     3	// license that can be found in the LICENSE file.
     4	
     5	package ssa
     6	
     7	import (
     8		"errors"
     9		"fmt"
    10		"os"
    11	)
    12	
    13	const uintSize = 32 << (^uint(0) >> 32 & 1) // 32 or 64
    14	
    15	// bitset is a bit array for dense indexes.
    16	type bitset []uint
    17	
    18	func newBitset(n int) bitset {
    19		return make(bitset, (n+uintSize-1)/uintSize)
    20	}
    21	
    22	func (bs bitset) Reset() {
    23		for i := range bs {
    24			bs[i] = 0
    25		}
    26	}
    27	
    28	func (bs bitset) Set(idx uint32) {
    29		bs[idx/uintSize] |= 1 << (idx % uintSize)
    30	}
    31	
    32	func (bs bitset) Clear(idx uint32) {
    33		bs[idx/uintSize] &^= 1 << (idx % uintSize)
    34	}
    35	
    36	func (bs bitset) Test(idx uint32) bool {
    37		return bs[idx/uintSize]&(1<<(idx%uintSize)) != 0
    38	}
    39	
// undoType is the discriminant of a posetUndo entry: it says which kind of
// operation the entry reverts.
type undoType uint8

const (
	// undoInvalid is the zero value of undoType; none of the undo push
	// helpers visible in this file record it.
	undoInvalid    undoType = iota
	undoCheckpoint          // a checkpoint to group undo passes
	undoSetChl              // change back left child of undo.idx to undo.edge
	undoSetChr              // change back right child of undo.idx to undo.edge
	undoNonEqual            // forget that SSA value undo.ID is non-equal to undo.idx (another ID)
	undoNewNode             // remove new node created for SSA value undo.ID
	undoAliasNode           // unalias SSA value undo.ID so that it points back to node index undo.idx
	undoNewRoot             // remove node undo.idx from root list
	undoChangeRoot          // remove node undo.idx from root list, and put back undo.edge.Target instead
	undoMergeRoot           // remove node undo.idx from root list, and put back its children instead
)
    54	
// posetUndo represents an undo pass to be performed.
// It is a union of fields that can be used to store information,
// and typ is the discriminant that specifies which kind
// of operation must be performed. Not all fields are always used;
// see the undoType constants for the meaning of each field per kind.
type posetUndo struct {
	typ  undoType
	idx  uint32
	ID   ID
	edge posetEdge
}
    65	
// Bit flags stored in poset.flags.
const (
	// Make poset handle constants as unsigned numbers.
	// Toggled through SetUnsigned and consulted when constants are compared.
	posetFlagUnsigned = 1 << iota
)
    70	
    71	// A poset edge. The zero value is the null/empty edge.
    72	// Packs target node index (31 bits) and strict flag (1 bit).
    73	type posetEdge uint32
    74	
    75	func newedge(t uint32, strict bool) posetEdge {
    76		s := uint32(0)
    77		if strict {
    78			s = 1
    79		}
    80		return posetEdge(t<<1 | s)
    81	}
    82	func (e posetEdge) Target() uint32 { return uint32(e) >> 1 }
    83	func (e posetEdge) Strict() bool   { return uint32(e)&1 != 0 }
    84	func (e posetEdge) String() string {
    85		s := fmt.Sprint(e.Target())
    86		if e.Strict() {
    87			s += "*"
    88		}
    89		return s
    90	}
    91	
// posetNode is a node of a DAG within the poset.
// It holds at most two outgoing edges, l (left) and r (right);
// a zero edge means that the corresponding child is absent.
type posetNode struct {
	l, r posetEdge
}
    96	
// poset is a data structure that can represent a partially ordered set
// of SSA values. Given a binary relation that creates a partial order (e.g. '<'),
// clients can record relations between SSA values using SetOrder, and later
// check relations (in the transitive closure) with Ordered. For instance,
// if SetOrder is called to record that A<B and B<C, Ordered will later confirm
// that A<C.
//
// It is possible to record equality relations between SSA values with SetEqual and check
// equality with Equal. Equality propagates into the transitive closure for the partial
// order so that if we know that A<B<C and later learn that A==D, Ordered will return
// true for D<C.
//
// poset will refuse to record new relations that contradict existing relations:
// for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
// calling SetEqual for C==A will fail.
//
// It is also possible to record inequality relations between nodes with SetNonEqual;
// given that non-equality is not transitive, the only effect is that a later call
// to SetEqual for the same values will fail. NonEqual checks whether it is known that
// the nodes are different, either because SetNonEqual was called before, or because
// we know that they are strictly ordered.
//
// It is implemented as a forest of DAGs; in each DAG, if there is a (directed) path
// from node A to B, it means that A<B (or A<=B). Equality is represented by mapping
// two SSA values to the same DAG node; when a new equality relation is recorded
// between two existing nodes, the nodes are merged, adjusting incoming and outgoing edges.
//
// Constants are specially treated. When a constant is added to the poset, it is
// immediately linked to other constants already present; so for instance if the
// poset knows that x<=3, and then x is tested against 5, 5 is first added and linked
// to 3 (using 3<5), so that the poset knows that x<=3<5; at that point, it is able
// to answer x<5 correctly.
//
// poset is designed to be memory efficient and to do few allocations during normal usage.
// Most internal data structures are pre-allocated and flat, so for instance adding a
// new relation does not cause any allocation. For performance reasons,
// each node has only up to two outgoing edges (like a binary tree), so intermediate
// "dummy" nodes are required to represent more than two relations. For instance,
// to record that A<I, A<J, A<K (with no known relation between I,J,K), we create the
// following DAG:
//
//         A
//        / \
//       I  dummy
//           /  \
//          J    K
//
type poset struct {
	lastidx   uint32        // last generated dense index
	flags     uint8         // internal flags
	values    map[ID]uint32 // map SSA values to dense indexes
	constants []*Value      // record SSA constants together with their value
	nodes     []posetNode   // nodes (in all DAGs)
	roots     []uint32      // list of root nodes (forest)
	noneq     map[ID]bitset // non-equal relations
	undo      []posetUndo   // undo chain
}
   154	
   155	func newPoset() *poset {
   156		return &poset{
   157			values:    make(map[ID]uint32),
   158			constants: make([]*Value, 0, 8),
   159			nodes:     make([]posetNode, 1, 16),
   160			roots:     make([]uint32, 0, 4),
   161			noneq:     make(map[ID]bitset),
   162			undo:      make([]posetUndo, 0, 4),
   163		}
   164	}
   165	
   166	func (po *poset) SetUnsigned(uns bool) {
   167		if uns {
   168			po.flags |= posetFlagUnsigned
   169		} else {
   170			po.flags &^= posetFlagUnsigned
   171		}
   172	}
   173	
   174	// Handle children
   175	func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
   176	func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
   177	func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
   178	func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
   179	func (po *poset) children(i uint32) (posetEdge, posetEdge) {
   180		return po.nodes[i].l, po.nodes[i].r
   181	}
   182	
   183	// upush records a new undo step. It can be used for simple
   184	// undo passes that record up to one index and one edge.
   185	func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
   186		po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
   187	}
   188	
   189	// upushnew pushes an undo pass for a new node
   190	func (po *poset) upushnew(id ID, idx uint32) {
   191		po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
   192	}
   193	
   194	// upushneq pushes a new undo pass for a nonequal relation
   195	func (po *poset) upushneq(id1 ID, id2 ID) {
   196		po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: id1, idx: uint32(id2)})
   197	}
   198	
   199	// upushalias pushes a new undo pass for aliasing two nodes
   200	func (po *poset) upushalias(id ID, i2 uint32) {
   201		po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
   202	}
   203	
   204	// addchild adds i2 as direct child of i1.
   205	func (po *poset) addchild(i1, i2 uint32, strict bool) {
   206		i1l, i1r := po.children(i1)
   207		e2 := newedge(i2, strict)
   208	
   209		if i1l == 0 {
   210			po.setchl(i1, e2)
   211			po.upush(undoSetChl, i1, 0)
   212		} else if i1r == 0 {
   213			po.setchr(i1, e2)
   214			po.upush(undoSetChr, i1, 0)
   215		} else {
   216			// If n1 already has two children, add an intermediate dummy
   217			// node to record the relation correctly (without relating
   218			// n2 to other existing nodes). Use a non-deterministic value
   219			// to decide whether to append on the left or the right, to avoid
   220			// creating degenerated chains.
   221			//
   222			//      n1
   223			//     /  \
   224			//   i1l  dummy
   225			//        /   \
   226			//      i1r   n2
   227			//
   228			dummy := po.newnode(nil)
   229			if (i1^i2)&1 != 0 { // non-deterministic
   230				po.setchl(dummy, i1r)
   231				po.setchr(dummy, e2)
   232				po.setchr(i1, newedge(dummy, false))
   233				po.upush(undoSetChr, i1, i1r)
   234			} else {
   235				po.setchl(dummy, i1l)
   236				po.setchr(dummy, e2)
   237				po.setchl(i1, newedge(dummy, false))
   238				po.upush(undoSetChl, i1, i1l)
   239			}
   240		}
   241	}
   242	
   243	// newnode allocates a new node bound to SSA value n.
   244	// If n is nil, this is a dummy node (= only used internally).
   245	func (po *poset) newnode(n *Value) uint32 {
   246		i := po.lastidx + 1
   247		po.lastidx++
   248		po.nodes = append(po.nodes, posetNode{})
   249		if n != nil {
   250			if po.values[n.ID] != 0 {
   251				panic("newnode for Value already inserted")
   252			}
   253			po.values[n.ID] = i
   254			po.upushnew(n.ID, i)
   255		} else {
   256			po.upushnew(0, i)
   257		}
   258		return i
   259	}
   260	
   261	// lookup searches for a SSA value into the forest of DAGS, and return its node.
   262	// Constants are materialized on the fly during lookup.
   263	func (po *poset) lookup(n *Value) (uint32, bool) {
   264		i, f := po.values[n.ID]
   265		if !f && n.isGenericIntConst() {
   266			po.newconst(n)
   267			i, f = po.values[n.ID]
   268		}
   269		return i, f
   270	}
   271	
   272	// newconst creates a node for a constant. It links it to other constants, so
   273	// that n<=5 is detected true when n<=3 is known to be true.
   274	// TODO: this is O(N), fix it.
   275	func (po *poset) newconst(n *Value) {
   276		if !n.isGenericIntConst() {
   277			panic("newconst on non-constant")
   278		}
   279	
   280		// If this is the first constant, put it into a new root, as
   281		// we can't record an existing connection so we don't have
   282		// a specific DAG to add it to.
   283		if len(po.constants) == 0 {
   284			i := po.newnode(n)
   285			po.roots = append(po.roots, i)
   286			po.upush(undoNewRoot, i, 0)
   287			po.constants = append(po.constants, n)
   288			return
   289		}
   290	
   291		// Find the lower and upper bound among existing constants. That is,
   292		// find the higher constant that is lower than the one that we're adding,
   293		// and the lower constant that is higher.
   294		// The loop is duplicated to handle signed and unsigned comparison,
   295		// depending on how the poset was configured.
   296		var lowerptr, higherptr *Value
   297	
   298		if po.flags&posetFlagUnsigned != 0 {
   299			var lower, higher uint64
   300			val1 := n.AuxUnsigned()
   301			for _, ptr := range po.constants {
   302				val2 := ptr.AuxUnsigned()
   303				if val1 == val2 {
   304					po.aliasnode(ptr, n)
   305					return
   306				}
   307				if val2 < val1 && (lowerptr == nil || val2 > lower) {
   308					lower = val2
   309					lowerptr = ptr
   310				} else if val2 > val1 && (higherptr == nil || val2 < higher) {
   311					higher = val2
   312					higherptr = ptr
   313				}
   314			}
   315		} else {
   316			var lower, higher int64
   317			val1 := n.AuxInt
   318			for _, ptr := range po.constants {
   319				val2 := ptr.AuxInt
   320				if val1 == val2 {
   321					po.aliasnode(ptr, n)
   322					return
   323				}
   324				if val2 < val1 && (lowerptr == nil || val2 > lower) {
   325					lower = val2
   326					lowerptr = ptr
   327				} else if val2 > val1 && (higherptr == nil || val2 < higher) {
   328					higher = val2
   329					higherptr = ptr
   330				}
   331			}
   332		}
   333	
   334		if lowerptr == nil && higherptr == nil {
   335			// This should not happen, as at least one
   336			// other constant must exist if we get here.
   337			panic("no constant found")
   338		}
   339	
   340		// Create the new node and connect it to the bounds, so that
   341		// lower < n < higher. We could have found both bounds or only one
   342		// of them, depending on what other constants are present in the poset.
   343		// Notice that we always link constants together, so they
   344		// are always part of the same DAG.
   345		i := po.newnode(n)
   346		switch {
   347		case lowerptr != nil && higherptr != nil:
   348			// Both bounds are present, record lower < n < higher.
   349			po.addchild(po.values[lowerptr.ID], i, true)
   350			po.addchild(i, po.values[higherptr.ID], true)
   351	
   352		case lowerptr != nil:
   353			// Lower bound only, record lower < n.
   354			po.addchild(po.values[lowerptr.ID], i, true)
   355	
   356		case higherptr != nil:
   357			// Higher bound only. To record n < higher, we need
   358			// a dummy root:
   359			//
   360			//        dummy
   361			//        /   \
   362			//      root   \
   363			//       /      n
   364			//     ....    /
   365			//       \    /
   366			//       higher
   367			//
   368			i2 := po.values[higherptr.ID]
   369			r2 := po.findroot(i2)
   370			dummy := po.newnode(nil)
   371			po.changeroot(r2, dummy)
   372			po.upush(undoChangeRoot, dummy, newedge(r2, false))
   373			po.addchild(dummy, r2, false)
   374			po.addchild(dummy, i, false)
   375			po.addchild(i, i2, true)
   376		}
   377	
   378		po.constants = append(po.constants, n)
   379	}
   380	
   381	// aliasnode records that n2 is an alias of n1
   382	func (po *poset) aliasnode(n1, n2 *Value) {
   383		i1 := po.values[n1.ID]
   384		if i1 == 0 {
   385			panic("aliasnode for non-existing node")
   386		}
   387	
   388		i2 := po.values[n2.ID]
   389		if i2 != 0 {
   390			// Rename all references to i2 into i1
   391			// (do not touch i1 itself, otherwise we can create useless self-loops)
   392			for idx, n := range po.nodes {
   393				if uint32(idx) != i1 {
   394					l, r := n.l, n.r
   395					if l.Target() == i2 {
   396						po.setchl(uint32(idx), newedge(i1, l.Strict()))
   397						po.upush(undoSetChl, uint32(idx), l)
   398					}
   399					if r.Target() == i2 {
   400						po.setchr(uint32(idx), newedge(i1, r.Strict()))
   401						po.upush(undoSetChr, uint32(idx), r)
   402					}
   403				}
   404			}
   405	
   406			// Reassign all existing IDs that point to i2 to i1.
   407			// This includes n2.ID.
   408			for k, v := range po.values {
   409				if v == i2 {
   410					po.values[k] = i1
   411					po.upushalias(k, i2)
   412				}
   413			}
   414		} else {
   415			// n2.ID wasn't seen before, so record it as alias to i1
   416			po.values[n2.ID] = i1
   417			po.upushalias(n2.ID, 0)
   418		}
   419	}
   420	
   421	func (po *poset) isroot(r uint32) bool {
   422		for i := range po.roots {
   423			if po.roots[i] == r {
   424				return true
   425			}
   426		}
   427		return false
   428	}
   429	
   430	func (po *poset) changeroot(oldr, newr uint32) {
   431		for i := range po.roots {
   432			if po.roots[i] == oldr {
   433				po.roots[i] = newr
   434				return
   435			}
   436		}
   437		panic("changeroot on non-root")
   438	}
   439	
   440	func (po *poset) removeroot(r uint32) {
   441		for i := range po.roots {
   442			if po.roots[i] == r {
   443				po.roots = append(po.roots[:i], po.roots[i+1:]...)
   444				return
   445			}
   446		}
   447		panic("removeroot on non-root")
   448	}
   449	
   450	// dfs performs a depth-first search within the DAG whose root is r.
   451	// f is the visit function called for each node; if it returns true,
   452	// the search is aborted and true is returned. The root node is
   453	// visited too.
   454	// If strict, ignore edges across a path until at least one
   455	// strict edge is found. For instance, for a chain A<=B<=C<D<=E<F,
   456	// a strict walk visits D,E,F.
   457	// If the visit ends, false is returned.
   458	func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
   459		closed := newBitset(int(po.lastidx + 1))
   460		open := make([]uint32, 1, 64)
   461		open[0] = r
   462	
   463		if strict {
   464			// Do a first DFS; walk all paths and stop when we find a strict
   465			// edge, building a "next" list of nodes reachable through strict
   466			// edges. This will be the bootstrap open list for the real DFS.
   467			next := make([]uint32, 0, 64)
   468	
   469			for len(open) > 0 {
   470				i := open[len(open)-1]
   471				open = open[:len(open)-1]
   472	
   473				// Don't visit the same node twice. Notice that all nodes
   474				// across non-strict paths are still visited at least once, so
   475				// a non-strict path can never obscure a strict path to the
   476				// same node.
   477				if !closed.Test(i) {
   478					closed.Set(i)
   479	
   480					l, r := po.children(i)
   481					if l != 0 {
   482						if l.Strict() {
   483							next = append(next, l.Target())
   484						} else {
   485							open = append(open, l.Target())
   486						}
   487					}
   488					if r != 0 {
   489						if r.Strict() {
   490							next = append(next, r.Target())
   491						} else {
   492							open = append(open, r.Target())
   493						}
   494					}
   495				}
   496			}
   497			open = next
   498			closed.Reset()
   499		}
   500	
   501		for len(open) > 0 {
   502			i := open[len(open)-1]
   503			open = open[:len(open)-1]
   504	
   505			if !closed.Test(i) {
   506				if f(i) {
   507					return true
   508				}
   509				closed.Set(i)
   510				l, r := po.children(i)
   511				if l != 0 {
   512					open = append(open, l.Target())
   513				}
   514				if r != 0 {
   515					open = append(open, r.Target())
   516				}
   517			}
   518		}
   519		return false
   520	}
   521	
   522	// Returns true if there is a path from i1 to i2.
   523	// If strict ==  true: if the function returns true, then i1 <  i2.
   524	// If strict == false: if the function returns true, then i1 <= i2.
   525	// If the function returns false, no relation is known.
   526	func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
   527		return po.dfs(i1, strict, func(n uint32) bool {
   528			return n == i2
   529		})
   530	}
   531	
   532	// findroot finds i's root, that is which DAG contains i.
   533	// Returns the root; if i is itself a root, it is returned.
   534	// Panic if i is not in any DAG.
   535	func (po *poset) findroot(i uint32) uint32 {
   536		// TODO(rasky): if needed, a way to speed up this search is
   537		// storing a bitset for each root using it as a mini bloom filter
   538		// of nodes present under that root.
   539		for _, r := range po.roots {
   540			if po.reaches(r, i, false) {
   541				return r
   542			}
   543		}
   544		panic("findroot didn't find any root")
   545	}
   546	
   547	// mergeroot merges two DAGs into one DAG by creating a new dummy root
   548	func (po *poset) mergeroot(r1, r2 uint32) uint32 {
   549		r := po.newnode(nil)
   550		po.setchl(r, newedge(r1, false))
   551		po.setchr(r, newedge(r2, false))
   552		po.changeroot(r1, r)
   553		po.removeroot(r2)
   554		po.upush(undoMergeRoot, r, 0)
   555		return r
   556	}
   557	
   558	// collapsepath marks i1 and i2 as equal and collapses as equal all
   559	// nodes across all paths between i1 and i2. If a strict edge is
   560	// found, the function does not modify the DAG and returns false.
   561	func (po *poset) collapsepath(n1, n2 *Value) bool {
   562		i1, i2 := po.values[n1.ID], po.values[n2.ID]
   563		if po.reaches(i1, i2, true) {
   564			return false
   565		}
   566	
   567		// TODO: for now, only handle the simple case of i2 being child of i1
   568		l, r := po.children(i1)
   569		if l.Target() == i2 || r.Target() == i2 {
   570			po.aliasnode(n1, n2)
   571			po.addchild(i1, i2, false)
   572			return true
   573		}
   574		return true
   575	}
   576	
   577	// Check whether it is recorded that id1!=id2
   578	func (po *poset) isnoneq(id1, id2 ID) bool {
   579		if id1 < id2 {
   580			id1, id2 = id2, id1
   581		}
   582	
   583		// Check if we recorded a non-equal relation before
   584		if bs, ok := po.noneq[id1]; ok && bs.Test(uint32(id2)) {
   585			return true
   586		}
   587		return false
   588	}
   589	
   590	// Record that id1!=id2
   591	func (po *poset) setnoneq(id1, id2 ID) {
   592		if id1 < id2 {
   593			id1, id2 = id2, id1
   594		}
   595		bs := po.noneq[id1]
   596		if bs == nil {
   597			// Given that we record non-equality relations using the
   598			// higher ID as a key, the bitsize will never change size.
   599			// TODO(rasky): if memory is a problem, consider allocating
   600			// a small bitset and lazily grow it when higher IDs arrive.
   601			bs = newBitset(int(id1))
   602			po.noneq[id1] = bs
   603		} else if bs.Test(uint32(id2)) {
   604			// Already recorded
   605			return
   606		}
   607		bs.Set(uint32(id2))
   608		po.upushneq(id1, id2)
   609	}
   610	
   611	// CheckIntegrity verifies internal integrity of a poset. It is intended
   612	// for debugging purposes.
   613	func (po *poset) CheckIntegrity() (err error) {
   614		// Record which index is a constant
   615		constants := newBitset(int(po.lastidx + 1))
   616		for _, c := range po.constants {
   617			if idx, ok := po.values[c.ID]; !ok {
   618				err = errors.New("node missing for constant")
   619				return err
   620			} else {
   621				constants.Set(idx)
   622			}
   623		}
   624	
   625		// Verify that each node appears in a single DAG, and that
   626		// all constants are within the same DAG
   627		var croot uint32
   628		seen := newBitset(int(po.lastidx + 1))
   629		for _, r := range po.roots {
   630			if r == 0 {
   631				err = errors.New("empty root")
   632				return
   633			}
   634	
   635			po.dfs(r, false, func(i uint32) bool {
   636				if seen.Test(i) {
   637					err = errors.New("duplicate node")
   638					return true
   639				}
   640				seen.Set(i)
   641				if constants.Test(i) {
   642					if croot == 0 {
   643						croot = r
   644					} else if croot != r {
   645						err = errors.New("constants are in different DAGs")
   646						return true
   647					}
   648				}
   649				return false
   650			})
   651			if err != nil {
   652				return
   653			}
   654		}
   655	
   656		// Verify that values contain the minimum set
   657		for id, idx := range po.values {
   658			if !seen.Test(idx) {
   659				err = fmt.Errorf("spurious value [%d]=%d", id, idx)
   660				return
   661			}
   662		}
   663	
   664		// Verify that only existing nodes have non-zero children
   665		for i, n := range po.nodes {
   666			if n.l|n.r != 0 {
   667				if !seen.Test(uint32(i)) {
   668					err = fmt.Errorf("children of unknown node %d->%v", i, n)
   669					return
   670				}
   671				if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
   672					err = fmt.Errorf("self-loop on node %d", i)
   673					return
   674				}
   675			}
   676		}
   677	
   678		return
   679	}
   680	
   681	// CheckEmpty checks that a poset is completely empty.
   682	// It can be used for debugging purposes, as a poset is supposed to
   683	// be empty after it's fully rolled back through Undo.
   684	func (po *poset) CheckEmpty() error {
   685		if len(po.nodes) != 1 {
   686			return fmt.Errorf("non-empty nodes list: %v", po.nodes)
   687		}
   688		if len(po.values) != 0 {
   689			return fmt.Errorf("non-empty value map: %v", po.values)
   690		}
   691		if len(po.roots) != 0 {
   692			return fmt.Errorf("non-empty root list: %v", po.roots)
   693		}
   694		if len(po.constants) != 0 {
   695			return fmt.Errorf("non-empty constants: %v", po.constants)
   696		}
   697		if len(po.undo) != 0 {
   698			return fmt.Errorf("non-empty undo list: %v", po.undo)
   699		}
   700		if po.lastidx != 0 {
   701			return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
   702		}
   703		for _, bs := range po.noneq {
   704			for _, x := range bs {
   705				if x != 0 {
   706					return fmt.Errorf("non-empty noneq map")
   707				}
   708			}
   709		}
   710		return nil
   711	}
   712	
   713	// DotDump dumps the poset in graphviz format to file fn, with the specified title.
   714	func (po *poset) DotDump(fn string, title string) error {
   715		f, err := os.Create(fn)
   716		if err != nil {
   717			return err
   718		}
   719		defer f.Close()
   720	
   721		// Create reverse index mapping (taking aliases into account)
   722		names := make(map[uint32]string)
   723		for id, i := range po.values {
   724			s := names[i]
   725			if s == "" {
   726				s = fmt.Sprintf("v%d", id)
   727			} else {
   728				s += fmt.Sprintf(", v%d", id)
   729			}
   730			names[i] = s
   731		}
   732	
   733		// Create constant mapping
   734		consts := make(map[uint32]int64)
   735		for _, v := range po.constants {
   736			idx := po.values[v.ID]
   737			if po.flags&posetFlagUnsigned != 0 {
   738				consts[idx] = int64(v.AuxUnsigned())
   739			} else {
   740				consts[idx] = v.AuxInt
   741			}
   742		}
   743	
   744		fmt.Fprintf(f, "digraph poset {\n")
   745		fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
   746		for ridx, r := range po.roots {
   747			fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
   748			po.dfs(r, false, func(i uint32) bool {
   749				if val, ok := consts[i]; ok {
   750					// Constant
   751					var vals string
   752					if po.flags&posetFlagUnsigned != 0 {
   753						vals = fmt.Sprint(uint64(val))
   754					} else {
   755						vals = fmt.Sprint(int64(val))
   756					}
   757					fmt.Fprintf(f, "\t\tnode%d [shape=box style=filled fillcolor=cadetblue1 label=<%s <font point-size=\"6\">%s [%d]</font>>]\n",
   758						i, vals, names[i], i)
   759				} else {
   760					// Normal SSA value
   761					fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
   762				}
   763				chl, chr := po.children(i)
   764				for _, ch := range []posetEdge{chl, chr} {
   765					if ch != 0 {
   766						if ch.Strict() {
   767							fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
   768						} else {
   769							fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
   770						}
   771					}
   772				}
   773				return false
   774			})
   775			fmt.Fprintf(f, "\t}\n")
   776		}
   777		fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
   778		fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
   779		fmt.Fprintf(f, "\tlabel=%q\n", title)
   780		fmt.Fprintf(f, "}\n")
   781		return nil
   782	}
   783	
   784	// Ordered reports whether n1<n2. It returns false either when it is
   785	// certain that n1<n2 is false, or if there is not enough information
   786	// to tell.
   787	// Complexity is O(n).
   788	func (po *poset) Ordered(n1, n2 *Value) bool {
   789		if n1.ID == n2.ID {
   790			panic("should not call Ordered with n1==n2")
   791		}
   792	
   793		i1, f1 := po.lookup(n1)
   794		i2, f2 := po.lookup(n2)
   795		if !f1 || !f2 {
   796			return false
   797		}
   798	
   799		return i1 != i2 && po.reaches(i1, i2, true)
   800	}
   801	
   802	// Ordered reports whether n1<=n2. It returns false either when it is
   803	// certain that n1<=n2 is false, or if there is not enough information
   804	// to tell.
   805	// Complexity is O(n).
   806	func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
   807		if n1.ID == n2.ID {
   808			panic("should not call Ordered with n1==n2")
   809		}
   810	
   811		i1, f1 := po.lookup(n1)
   812		i2, f2 := po.lookup(n2)
   813		if !f1 || !f2 {
   814			return false
   815		}
   816	
   817		return i1 == i2 || po.reaches(i1, i2, false)
   818	}
   819	
   820	// Equal reports whether n1==n2. It returns false either when it is
   821	// certain that n1==n2 is false, or if there is not enough information
   822	// to tell.
   823	// Complexity is O(1).
   824	func (po *poset) Equal(n1, n2 *Value) bool {
   825		if n1.ID == n2.ID {
   826			panic("should not call Equal with n1==n2")
   827		}
   828	
   829		i1, f1 := po.lookup(n1)
   830		i2, f2 := po.lookup(n2)
   831		return f1 && f2 && i1 == i2
   832	}
   833	
   834	// NonEqual reports whether n1!=n2. It returns false either when it is
   835	// certain that n1!=n2 is false, or if there is not enough information
   836	// to tell.
   837	// Complexity is O(n) (because it internally calls Ordered to see if we
   838	// can infer n1!=n2 from n1<n2 or n2<n1).
   839	func (po *poset) NonEqual(n1, n2 *Value) bool {
   840		if n1.ID == n2.ID {
   841			panic("should not call Equal with n1==n2")
   842		}
   843		if po.isnoneq(n1.ID, n2.ID) {
   844			return true
   845		}
   846	
   847		// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
   848		if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
   849			return true
   850		}
   851	
   852		return false
   853	}
   854	
   855	// setOrder records that n1<n2 or n1<=n2 (depending on strict).
   856	// Implements SetOrder() and SetOrderOrEqual()
   857	func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
   858		// If we are trying to record n1<=n2 but we learned that n1!=n2,
   859		// record n1<n2, as it provides more information.
   860		if !strict && po.isnoneq(n1.ID, n2.ID) {
   861			strict = true
   862		}
   863	
   864		i1, f1 := po.lookup(n1)
   865		i2, f2 := po.lookup(n2)
   866	
   867		switch {
   868		case !f1 && !f2:
   869			// Neither n1 nor n2 are in the poset, so they are not related
   870			// in any way to existing nodes.
   871			// Create a new DAG to record the relation.
   872			i1, i2 = po.newnode(n1), po.newnode(n2)
   873			po.roots = append(po.roots, i1)
   874			po.upush(undoNewRoot, i1, 0)
   875			po.addchild(i1, i2, strict)
   876	
   877		case f1 && !f2:
   878			// n1 is in one of the DAGs, while n2 is not. Add n2 as children
   879			// of n1.
   880			i2 = po.newnode(n2)
   881			po.addchild(i1, i2, strict)
   882	
   883		case !f1 && f2:
   884			// n1 is not in any DAG but n2 is. If n2 is a root, we can put
   885			// n1 in its place as a root; otherwise, we need to create a new
   886			// dummy root to record the relation.
   887			i1 = po.newnode(n1)
   888	
   889			if po.isroot(i2) {
   890				po.changeroot(i2, i1)
   891				po.upush(undoChangeRoot, i1, newedge(i2, strict))
   892				po.addchild(i1, i2, strict)
   893				return true
   894			}
   895	
   896			// Search for i2's root; this requires a O(n) search on all
   897			// DAGs
   898			r := po.findroot(i2)
   899	
   900			// Re-parent as follows:
   901			//
   902			//                  dummy
   903			//     r            /   \
   904			//      \   ===>   r    i1
   905			//      i2          \   /
   906			//                    i2
   907			//
   908			dummy := po.newnode(nil)
   909			po.changeroot(r, dummy)
   910			po.upush(undoChangeRoot, dummy, newedge(r, false))
   911			po.addchild(dummy, r, false)
   912			po.addchild(dummy, i1, false)
   913			po.addchild(i1, i2, strict)
   914	
   915		case f1 && f2:
   916			// If the nodes are aliased, fail only if we're setting a strict order
   917			// (that is, we cannot set n1<n2 if n1==n2).
   918			if i1 == i2 {
   919				return !strict
   920			}
   921	
   922			// Both n1 and n2 are in the poset. This is the complex part of the algorithm
   923			// as we need to find many different cases and DAG shapes.
   924	
   925			// Check if n1 somehow reaches n2
   926			if po.reaches(i1, i2, false) {
   927				// This is the table of all cases we need to handle:
   928				//
   929				//      DAG          New      Action
   930				//      ---------------------------------------------------
   931				// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
   932				// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
   933				// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
   934				// #4:  N1<X<N2   |  N1<N2  | do nothing
   935	
   936				// Check if we're in case #2
   937				if strict && !po.reaches(i1, i2, true) {
   938					po.addchild(i1, i2, true)
   939					return true
   940				}
   941	
   942				// Case #1, #3 o #4: nothing to do
   943				return true
   944			}
   945	
   946			// Check if n2 somehow reaches n1
   947			if po.reaches(i2, i1, false) {
   948				// This is the table of all cases we need to handle:
   949				//
   950				//      DAG           New      Action
   951				//      ---------------------------------------------------
   952				// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
   953				// #6:  N2<=X<=N1  |  N1<N2  | contradiction
   954				// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
   955				// #8:  N2<X<N1    |  N1<N2  | contradiction
   956	
   957				if strict {
   958					// Cases #6 and #8: contradiction
   959					return false
   960				}
   961	
   962				// We're in case #5 or #7. Try to collapse path, and that will
   963				// fail if it realizes that we are in case #7.
   964				return po.collapsepath(n2, n1)
   965			}
   966	
   967			// We don't know of any existing relation between n1 and n2. They could
   968			// be part of the same DAG or not.
   969			// Find their roots to check whether they are in the same DAG.
   970			r1, r2 := po.findroot(i1), po.findroot(i2)
   971			if r1 != r2 {
   972				// We need to merge the two DAGs to record a relation between the nodes
   973				po.mergeroot(r1, r2)
   974			}
   975	
   976			// Connect n1 and n2
   977			po.addchild(i1, i2, strict)
   978		}
   979	
   980		return true
   981	}
   982	
   983	// SetOrder records that n1<n2. Returns false if this is a contradiction
   984	// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   985	func (po *poset) SetOrder(n1, n2 *Value) bool {
   986		if n1.ID == n2.ID {
   987			panic("should not call SetOrder with n1==n2")
   988		}
   989		return po.setOrder(n1, n2, true)
   990	}
   991	
   992	// SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
   993	// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   994	func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
   995		if n1.ID == n2.ID {
   996			panic("should not call SetOrder with n1==n2")
   997		}
   998		return po.setOrder(n1, n2, false)
   999	}
  1000	
  1001	// SetEqual records that n1==n2. Returns false if this is a contradiction
  1002	// (that is, if it is already recorded that n1<n2 or n2<n1).
  1003	// Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
  1004	func (po *poset) SetEqual(n1, n2 *Value) bool {
  1005		if n1.ID == n2.ID {
  1006			panic("should not call Add with n1==n2")
  1007		}
  1008	
  1009		// If we recorded that n1!=n2, this is a contradiction.
  1010		if po.isnoneq(n1.ID, n2.ID) {
  1011			return false
  1012		}
  1013	
  1014		i1, f1 := po.lookup(n1)
  1015		i2, f2 := po.lookup(n2)
  1016	
  1017		switch {
  1018		case !f1 && !f2:
  1019			i1 = po.newnode(n1)
  1020			po.roots = append(po.roots, i1)
  1021			po.upush(undoNewRoot, i1, 0)
  1022			po.aliasnode(n1, n2)
  1023		case f1 && !f2:
  1024			po.aliasnode(n1, n2)
  1025		case !f1 && f2:
  1026			po.aliasnode(n2, n1)
  1027		case f1 && f2:
  1028			if i1 == i2 {
  1029				// Already aliased, ignore
  1030				return true
  1031			}
  1032	
  1033			// If we already knew that n1<=n2, we can collapse the path to
  1034			// record n1==n2 (and viceversa).
  1035			if po.reaches(i1, i2, false) {
  1036				return po.collapsepath(n1, n2)
  1037			}
  1038			if po.reaches(i2, i1, false) {
  1039				return po.collapsepath(n2, n1)
  1040			}
  1041	
  1042			r1 := po.findroot(i1)
  1043			r2 := po.findroot(i2)
  1044			if r1 != r2 {
  1045				// Merge the two DAGs so we can record relations between the nodes
  1046				po.mergeroot(r1, r2)
  1047			}
  1048	
  1049			// Set n2 as alias of n1. This will also update all the references
  1050			// to n2 to become references to n1
  1051			po.aliasnode(n1, n2)
  1052	
  1053			// Connect i2 (now dummy) as child of i1. This allows to keep the correct
  1054			// order with its children.
  1055			po.addchild(i1, i2, false)
  1056		}
  1057		return true
  1058	}
  1059	
  1060	// SetNonEqual records that n1!=n2. Returns false if this is a contradiction
  1061	// (that is, if it is already recorded that n1==n2).
  1062	// Complexity is O(n).
  1063	func (po *poset) SetNonEqual(n1, n2 *Value) bool {
  1064		if n1.ID == n2.ID {
  1065			panic("should not call Equal with n1==n2")
  1066		}
  1067	
  1068		// See if we already know this
  1069		if po.isnoneq(n1.ID, n2.ID) {
  1070			return true
  1071		}
  1072	
  1073		// Check if we're contradicting an existing relation
  1074		if po.Equal(n1, n2) {
  1075			return false
  1076		}
  1077	
  1078		// Record non-equality
  1079		po.setnoneq(n1.ID, n2.ID)
  1080	
  1081		// If we know that i1<=i2 but not i1<i2, learn that as we
  1082		// now know that they are not equal. Do the same for i2<=i1.
  1083		i1, f1 := po.lookup(n1)
  1084		i2, f2 := po.lookup(n2)
  1085		if f1 && f2 {
  1086			if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
  1087				po.addchild(i1, i2, true)
  1088			}
  1089			if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
  1090				po.addchild(i2, i1, true)
  1091			}
  1092		}
  1093	
  1094		return true
  1095	}
  1096	
  1097	// Checkpoint saves the current state of the DAG so that it's possible
  1098	// to later undo this state.
  1099	// Complexity is O(1).
  1100	func (po *poset) Checkpoint() {
  1101		po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
  1102	}
  1103	
  1104	// Undo restores the state of the poset to the previous checkpoint.
  1105	// Complexity depends on the type of operations that were performed
  1106	// since the last checkpoint; each Set* operation creates an undo
  1107	// pass which Undo has to revert with a worst-case complexity of O(n).
  1108	func (po *poset) Undo() {
  1109		if len(po.undo) == 0 {
  1110			panic("empty undo stack")
  1111		}
  1112	
  1113		for len(po.undo) > 0 {
  1114			pass := po.undo[len(po.undo)-1]
  1115			po.undo = po.undo[:len(po.undo)-1]
  1116	
  1117			switch pass.typ {
  1118			case undoCheckpoint:
  1119				return
  1120	
  1121			case undoSetChl:
  1122				po.setchl(pass.idx, pass.edge)
  1123	
  1124			case undoSetChr:
  1125				po.setchr(pass.idx, pass.edge)
  1126	
  1127			case undoNonEqual:
  1128				po.noneq[pass.ID].Clear(pass.idx)
  1129	
  1130			case undoNewNode:
  1131				if pass.idx != po.lastidx {
  1132					panic("invalid newnode index")
  1133				}
  1134				if pass.ID != 0 {
  1135					if po.values[pass.ID] != pass.idx {
  1136						panic("invalid newnode undo pass")
  1137					}
  1138					delete(po.values, pass.ID)
  1139				}
  1140				po.setchl(pass.idx, 0)
  1141				po.setchr(pass.idx, 0)
  1142				po.nodes = po.nodes[:pass.idx]
  1143				po.lastidx--
  1144	
  1145				// If it was the last inserted constant, remove it
  1146				nc := len(po.constants)
  1147				if nc > 0 && po.constants[nc-1].ID == pass.ID {
  1148					po.constants = po.constants[:nc-1]
  1149				}
  1150	
  1151			case undoAliasNode:
  1152				ID, prev := pass.ID, pass.idx
  1153				cur := po.values[ID]
  1154				if prev == 0 {
  1155					// Born as an alias, die as an alias
  1156					delete(po.values, ID)
  1157				} else {
  1158					if cur == prev {
  1159						panic("invalid aliasnode undo pass")
  1160					}
  1161					// Give it back previous value
  1162					po.values[ID] = prev
  1163				}
  1164	
  1165			case undoNewRoot:
  1166				i := pass.idx
  1167				l, r := po.children(i)
  1168				if l|r != 0 {
  1169					panic("non-empty root in undo newroot")
  1170				}
  1171				po.removeroot(i)
  1172	
  1173			case undoChangeRoot:
  1174				i := pass.idx
  1175				l, r := po.children(i)
  1176				if l|r != 0 {
  1177					panic("non-empty root in undo changeroot")
  1178				}
  1179				po.changeroot(i, pass.edge.Target())
  1180	
  1181			case undoMergeRoot:
  1182				i := pass.idx
  1183				l, r := po.children(i)
  1184				po.changeroot(i, l.Target())
  1185				po.roots = append(po.roots, r.Target())
  1186	
  1187			default:
  1188				panic(pass.typ)
  1189			}
  1190		}
  1191	}
  1192	

View as plain text