Source file src/cmd/compile/internal/ssa/gen/rulegen.go
1
2
3
4
5
6
7
8
9
10
11
12 package main
13
14 import (
15 "bufio"
16 "bytes"
17 "flag"
18 "fmt"
19 "go/format"
20 "io"
21 "io/ioutil"
22 "log"
23 "os"
24 "regexp"
25 "sort"
26 "strings"
27 )
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
// genLog is set by the -log flag. When true, every generated rewrite emits a
// call to logRule with the rule's source location just before it reports
// success. Intended for debugging only.
var genLog = flag.Bool("log", false, "generate code that logs; for debugging only")
53
// Rule is one rewrite rule read from a .rules file.
type Rule struct {
	rule string // rule text: "match [&& cond] -> result"
	loc  string // where the rule was read: "<arch><suffix>.rules:<line>"
}

// String describes the rule together with its source location, for use in
// diagnostics.
func (r Rule) String() string {
	const layout = "rule %q at %s"
	return fmt.Sprintf(layout, r.rule, r.loc)
}
62
// normalizeSpaces drops leading and trailing white space from s and collapses
// every interior run of white space into a single space.
func normalizeSpaces(s string) string {
	// Fields already ignores leading and trailing white space.
	words := strings.Fields(s)
	return strings.Join(words, " ")
}
66
67
68 func (r Rule) parse() (match, cond, result string) {
69 s := strings.Split(r.rule, "->")
70 if len(s) != 2 {
71 log.Fatalf("no arrow in %s", r)
72 }
73 match = normalizeSpaces(s[0])
74 result = normalizeSpaces(s[1])
75 cond = ""
76 if i := strings.Index(match, "&&"); i >= 0 {
77 cond = normalizeSpaces(match[i+2:])
78 match = normalizeSpaces(match[:i])
79 }
80 return match, cond, result
81 }
82
83 func genRules(arch arch) { genRulesSuffix(arch, "") }
84 func genSplitLoadRules(arch arch) { genRulesSuffix(arch, "splitload") }
85
86 func genRulesSuffix(arch arch, suff string) {
87
88 text, err := os.Open(arch.name + suff + ".rules")
89 if err != nil {
90 if suff == "" {
91
92 log.Fatalf("can't read rule file: %v", err)
93 }
94
95 return
96 }
97
98
99 blockrules := map[string][]Rule{}
100 oprules := map[string][]Rule{}
101
102
103 scanner := bufio.NewScanner(text)
104 rule := ""
105 var lineno int
106 var ruleLineno int
107 for scanner.Scan() {
108 lineno++
109 line := scanner.Text()
110 if i := strings.Index(line, "//"); i >= 0 {
111
112
113 line = line[:i]
114 }
115 rule += " " + line
116 rule = strings.TrimSpace(rule)
117 if rule == "" {
118 continue
119 }
120 if !strings.Contains(rule, "->") {
121 continue
122 }
123 if ruleLineno == 0 {
124 ruleLineno = lineno
125 }
126 if strings.HasSuffix(rule, "->") {
127 continue
128 }
129 if unbalanced(rule) {
130 continue
131 }
132
133 loc := fmt.Sprintf("%s%s.rules:%d", arch.name, suff, ruleLineno)
134 for _, rule2 := range expandOr(rule) {
135 for _, rule3 := range commute(rule2, arch) {
136 r := Rule{rule: rule3, loc: loc}
137 if rawop := strings.Split(rule3, " ")[0][1:]; isBlock(rawop, arch) {
138 blockrules[rawop] = append(blockrules[rawop], r)
139 } else {
140
141 match, _, _ := r.parse()
142 op, oparch, _, _, _, _ := parseValue(match, arch, loc)
143 opname := fmt.Sprintf("Op%s%s", oparch, op.name)
144 oprules[opname] = append(oprules[opname], r)
145 }
146 }
147 }
148 rule = ""
149 ruleLineno = 0
150 }
151 if err := scanner.Err(); err != nil {
152 log.Fatalf("scanner failed: %v\n", err)
153 }
154 if unbalanced(rule) {
155 log.Fatalf("%s.rules:%d: unbalanced rule: %v\n", arch.name, lineno, rule)
156 }
157
158
159 var ops []string
160 for op := range oprules {
161 ops = append(ops, op)
162 }
163 sort.Strings(ops)
164
165
166 w := new(bytes.Buffer)
167 fmt.Fprintf(w, "// Code generated from gen/%s%s.rules; DO NOT EDIT.\n", arch.name, suff)
168 fmt.Fprintln(w, "// generated with: cd gen; go run *.go")
169 fmt.Fprintln(w)
170 fmt.Fprintln(w, "package ssa")
171 fmt.Fprintln(w, "import \"fmt\"")
172 fmt.Fprintln(w, "import \"math\"")
173 fmt.Fprintln(w, "import \"cmd/internal/obj\"")
174 fmt.Fprintln(w, "import \"cmd/internal/objabi\"")
175 fmt.Fprintln(w, "import \"cmd/compile/internal/types\"")
176 fmt.Fprintln(w, "var _ = fmt.Println // in case not otherwise used")
177 fmt.Fprintln(w, "var _ = math.MinInt8 // in case not otherwise used")
178 fmt.Fprintln(w, "var _ = obj.ANOP // in case not otherwise used")
179 fmt.Fprintln(w, "var _ = objabi.GOROOT // in case not otherwise used")
180 fmt.Fprintln(w, "var _ = types.TypeMem // in case not otherwise used")
181 fmt.Fprintln(w)
182
183 const chunkSize = 10
184
185 fmt.Fprintf(w, "func rewriteValue%s%s(v *Value) bool {\n", arch.name, suff)
186 fmt.Fprintf(w, "switch v.Op {\n")
187 for _, op := range ops {
188 fmt.Fprintf(w, "case %s:\n", op)
189 fmt.Fprint(w, "return ")
190 for chunk := 0; chunk < len(oprules[op]); chunk += chunkSize {
191 if chunk > 0 {
192 fmt.Fprint(w, " || ")
193 }
194 fmt.Fprintf(w, "rewriteValue%s%s_%s_%d(v)", arch.name, suff, op, chunk)
195 }
196 fmt.Fprintln(w)
197 }
198 fmt.Fprintf(w, "}\n")
199 fmt.Fprintf(w, "return false\n")
200 fmt.Fprintf(w, "}\n")
201
202
203
204 for _, op := range ops {
205 for chunk := 0; chunk < len(oprules[op]); chunk += chunkSize {
206 buf := new(bytes.Buffer)
207 var canFail bool
208 endchunk := chunk + chunkSize
209 if endchunk > len(oprules[op]) {
210 endchunk = len(oprules[op])
211 }
212 for i, rule := range oprules[op][chunk:endchunk] {
213 match, cond, result := rule.parse()
214 fmt.Fprintf(buf, "// match: %s\n", match)
215 fmt.Fprintf(buf, "// cond: %s\n", cond)
216 fmt.Fprintf(buf, "// result: %s\n", result)
217
218 canFail = false
219 fmt.Fprintf(buf, "for {\n")
220 pos, _, matchCanFail := genMatch(buf, arch, match, rule.loc)
221 if pos == "" {
222 pos = "v.Pos"
223 }
224 if matchCanFail {
225 canFail = true
226 }
227
228 if cond != "" {
229 fmt.Fprintf(buf, "if !(%s) {\nbreak\n}\n", cond)
230 canFail = true
231 }
232 if !canFail && i+chunk != len(oprules[op])-1 {
233 log.Fatalf("unconditional rule %s is followed by other rules", match)
234 }
235
236 genResult(buf, arch, result, rule.loc, pos)
237 if *genLog {
238 fmt.Fprintf(buf, "logRule(\"%s\")\n", rule.loc)
239 }
240 fmt.Fprintf(buf, "return true\n")
241
242 fmt.Fprintf(buf, "}\n")
243 }
244 if canFail {
245 fmt.Fprintf(buf, "return false\n")
246 }
247
248 body := buf.String()
249
250 hasb := strings.Contains(body, " b.")
251 hasconfig := strings.Contains(body, "config.") || strings.Contains(body, "config)")
252 hasfe := strings.Contains(body, "fe.")
253 hastyps := strings.Contains(body, "typ.")
254 fmt.Fprintf(w, "func rewriteValue%s%s_%s_%d(v *Value) bool {\n", arch.name, suff, op, chunk)
255 if hasb || hasconfig || hasfe || hastyps {
256 fmt.Fprintln(w, "b := v.Block")
257 }
258 if hasconfig {
259 fmt.Fprintln(w, "config := b.Func.Config")
260 }
261 if hasfe {
262 fmt.Fprintln(w, "fe := b.Func.fe")
263 }
264 if hastyps {
265 fmt.Fprintln(w, "typ := &b.Func.Config.Types")
266 }
267 fmt.Fprint(w, body)
268 fmt.Fprintf(w, "}\n")
269 }
270 }
271
272
273
274 fmt.Fprintf(w, "func rewriteBlock%s%s(b *Block) bool {\n", arch.name, suff)
275 fmt.Fprintln(w, "config := b.Func.Config")
276 fmt.Fprintln(w, "typ := &config.Types")
277 fmt.Fprintln(w, "_ = typ")
278 fmt.Fprintln(w, "v := b.Control")
279 fmt.Fprintln(w, "_ = v")
280 fmt.Fprintf(w, "switch b.Kind {\n")
281 ops = nil
282 for op := range blockrules {
283 ops = append(ops, op)
284 }
285 sort.Strings(ops)
286 for _, op := range ops {
287 fmt.Fprintf(w, "case %s:\n", blockName(op, arch))
288 for _, rule := range blockrules[op] {
289 match, cond, result := rule.parse()
290 fmt.Fprintf(w, "// match: %s\n", match)
291 fmt.Fprintf(w, "// cond: %s\n", cond)
292 fmt.Fprintf(w, "// result: %s\n", result)
293
294 _, _, _, aux, s := extract(match)
295
296 loopw := new(bytes.Buffer)
297
298
299 pos := ""
300 checkOp := ""
301 if s[0] != "nil" {
302 if strings.Contains(s[0], "(") {
303 pos, checkOp, _ = genMatch0(loopw, arch, s[0], "v", map[string]struct{}{}, rule.loc)
304 } else {
305 fmt.Fprintf(loopw, "%s := b.Control\n", s[0])
306 }
307 }
308 if aux != "" {
309 fmt.Fprintf(loopw, "%s := b.Aux\n", aux)
310 }
311
312 if cond != "" {
313 fmt.Fprintf(loopw, "if !(%s) {\nbreak\n}\n", cond)
314 }
315
316
317 outop, _, _, aux, t := extract(result)
318 newsuccs := t[1:]
319
320
321 succs := s[1:]
322 m := map[string]bool{}
323 for _, succ := range succs {
324 if m[succ] {
325 log.Fatalf("can't have a repeat successor name %s in %s", succ, rule)
326 }
327 m[succ] = true
328 }
329 for _, succ := range newsuccs {
330 if !m[succ] {
331 log.Fatalf("unknown successor %s in %s", succ, rule)
332 }
333 delete(m, succ)
334 }
335 if len(m) != 0 {
336 log.Fatalf("unmatched successors %v in %s", m, rule)
337 }
338
339 fmt.Fprintf(loopw, "b.Kind = %s\n", blockName(outop, arch))
340 if t[0] == "nil" {
341 fmt.Fprintf(loopw, "b.SetControl(nil)\n")
342 } else {
343 if pos == "" {
344 pos = "v.Pos"
345 }
346 fmt.Fprintf(loopw, "b.SetControl(%s)\n", genResult0(loopw, arch, t[0], new(int), false, false, rule.loc, pos))
347 }
348 if aux != "" {
349 fmt.Fprintf(loopw, "b.Aux = %s\n", aux)
350 } else {
351 fmt.Fprintln(loopw, "b.Aux = nil")
352 }
353
354 succChanged := false
355 for i := 0; i < len(succs); i++ {
356 if succs[i] != newsuccs[i] {
357 succChanged = true
358 }
359 }
360 if succChanged {
361 if len(succs) != 2 {
362 log.Fatalf("changed successors, len!=2 in %s", rule)
363 }
364 if succs[0] != newsuccs[1] || succs[1] != newsuccs[0] {
365 log.Fatalf("can only handle swapped successors in %s", rule)
366 }
367 fmt.Fprintln(loopw, "b.swapSuccessors()")
368 }
369
370 if *genLog {
371 fmt.Fprintf(loopw, "logRule(\"%s\")\n", rule.loc)
372 }
373 fmt.Fprintf(loopw, "return true\n")
374
375 if checkOp != "" {
376 fmt.Fprintf(w, "for v.Op == %s {\n", checkOp)
377 } else {
378 fmt.Fprintf(w, "for {\n")
379 }
380 io.Copy(w, loopw)
381
382 fmt.Fprintf(w, "}\n")
383 }
384 }
385 fmt.Fprintf(w, "}\n")
386 fmt.Fprintf(w, "return false\n")
387 fmt.Fprintf(w, "}\n")
388
389
390 b := w.Bytes()
391 src, err := format.Source(b)
392 if err != nil {
393 fmt.Printf("%s\n", b)
394 panic(err)
395 }
396
397
398 err = ioutil.WriteFile("../rewrite"+arch.name+suff+".go", src, 0666)
399 if err != nil {
400 log.Fatalf("can't write output: %v\n", err)
401 }
402 }
403
404
405
406 func genMatch(w io.Writer, arch arch, match string, loc string) (pos, checkOp string, canFail bool) {
407 return genMatch0(w, arch, match, "v", map[string]struct{}{}, loc)
408 }
409
410 func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, loc string) (pos, checkOp string, canFail bool) {
411 if match[0] != '(' || match[len(match)-1] != ')' {
412 panic("non-compound expr in genMatch0: " + match)
413 }
414 op, oparch, typ, auxint, aux, args := parseValue(match, arch, loc)
415
416 checkOp = fmt.Sprintf("Op%s%s", oparch, op.name)
417
418 if op.faultOnNilArg0 || op.faultOnNilArg1 {
419
420 pos = v + ".Pos"
421 }
422
423 if typ != "" {
424 if !isVariable(typ) {
425
426 fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, typ)
427 canFail = true
428 } else {
429
430 if _, ok := m[typ]; ok {
431
432 fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, typ)
433 canFail = true
434 } else {
435 m[typ] = struct{}{}
436 fmt.Fprintf(w, "%s := %s.Type\n", typ, v)
437 }
438 }
439 }
440
441 if auxint != "" {
442 if !isVariable(auxint) {
443
444 fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, auxint)
445 canFail = true
446 } else {
447
448 if _, ok := m[auxint]; ok {
449 fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, auxint)
450 canFail = true
451 } else {
452 m[auxint] = struct{}{}
453 fmt.Fprintf(w, "%s := %s.AuxInt\n", auxint, v)
454 }
455 }
456 }
457
458 if aux != "" {
459 if !isVariable(aux) {
460
461 fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, aux)
462 canFail = true
463 } else {
464
465 if _, ok := m[aux]; ok {
466 fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, aux)
467 canFail = true
468 } else {
469 m[aux] = struct{}{}
470 fmt.Fprintf(w, "%s := %s.Aux\n", aux, v)
471 }
472 }
473 }
474
475
476 if n := len(args); n > 1 {
477 a := args[n-1]
478 if _, set := m[a]; !set && a != "_" && isVariable(a) {
479 m[a] = struct{}{}
480 fmt.Fprintf(w, "%s := %s.Args[%d]\n", a, v, n-1)
481
482
483 args = args[:n-1]
484 } else {
485 fmt.Fprintf(w, "_ = %s.Args[%d]\n", v, n-1)
486 }
487 }
488 for i, arg := range args {
489 if arg == "_" {
490 continue
491 }
492 if !strings.Contains(arg, "(") {
493
494 if _, ok := m[arg]; ok {
495
496
497
498
499 fmt.Fprintf(w, "if %s != %s.Args[%d] {\nbreak\n}\n", arg, v, i)
500 canFail = true
501 } else {
502
503 m[arg] = struct{}{}
504 fmt.Fprintf(w, "%s := %s.Args[%d]\n", arg, v, i)
505 }
506 continue
507 }
508
509 var argname string
510 colon := strings.Index(arg, ":")
511 openparen := strings.Index(arg, "(")
512 if colon >= 0 && openparen >= 0 && colon < openparen {
513
514 argname = arg[:colon]
515 arg = arg[colon+1:]
516 } else {
517
518 argname = fmt.Sprintf("%s_%d", v, i)
519 }
520 if argname == "b" {
521 log.Fatalf("don't name args 'b', it is ambiguous with blocks")
522 }
523
524 fmt.Fprintf(w, "%s := %s.Args[%d]\n", argname, v, i)
525 w2 := new(bytes.Buffer)
526 argPos, argCheckOp, _ := genMatch0(w2, arch, arg, argname, m, loc)
527 fmt.Fprintf(w, "if %s.Op != %s {\nbreak\n}\n", argname, argCheckOp)
528 io.Copy(w, w2)
529
530 if argPos != "" {
531
532
533
534
535
536 pos = argPos
537 }
538 canFail = true
539 }
540
541 if op.argLength == -1 {
542 fmt.Fprintf(w, "if len(%s.Args) != %d {\nbreak\n}\n", v, len(args))
543 canFail = true
544 }
545 return pos, checkOp, canFail
546 }
547
548 func genResult(w io.Writer, arch arch, result string, loc string, pos string) {
549 move := false
550 if result[0] == '@' {
551
552 s := strings.SplitN(result[1:], " ", 2)
553 fmt.Fprintf(w, "b = %s\n", s[0])
554 result = s[1]
555 move = true
556 }
557 genResult0(w, arch, result, new(int), true, move, loc, pos)
558 }
// genResult0 writes code that builds the value described by the result
// (sub-)expression result, and returns the Go expression naming that value.
//
// A result that does not start with '(' is a variable bound by the match. It
// is returned unchanged; at the top level, v is additionally turned into an
// OpCopy of it.
//
// Otherwise the op is resolved with parseValue, and:
//   - if top && !move: v itself is reset to the new op and reused;
//   - otherwise: a fresh value vN is created with b.NewValue0 at pos, and when
//     move && top, v is turned into an OpCopy of that new value.
// Auxint, aux and arguments are then set on the chosen value; arguments are
// generated recursively with top == false.
//
// alloc is a counter shared across the recursion that gives fresh values the
// distinct names v0, v1, .... loc is the rule location, used in diagnostics.
func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move bool, loc string, pos string) string {
	if result[0] != '(' {
		// Variable.
		if top {
			// v cannot simply be replaced by the variable here; make v a
			// copy of it instead.
			fmt.Fprintf(w, "v.reset(OpCopy)\n")
			fmt.Fprintf(w, "v.Type = %s.Type\n", result)
			fmt.Fprintf(w, "v.AddArg(%s)\n", result)
		}
		return result
	}

	op, oparch, typ, auxint, aux, args := parseValue(result, arch, loc)

	// Determine the type: an explicit <type> in the result wins; otherwise
	// fall back to the op's declared default type, if any.
	typeOverride := typ != ""
	if typ == "" && op.typ != "" {
		typ = typeName(op.typ)
	}

	var v string
	if top && !move {
		// Reuse v in place. Its type is only assigned when the rule spells
		// one out (presumably v.reset keeps v.Type otherwise; verify in ssa).
		v = "v"
		fmt.Fprintf(w, "v.reset(Op%s%s)\n", oparch, op.name)
		if typeOverride {
			fmt.Fprintf(w, "v.Type = %s\n", typ)
		}
	} else {
		// A freshly created value needs a type, explicit or default.
		if typ == "" {
			log.Fatalf("sub-expression %s (op=Op%s%s) at %s must have a type", result, oparch, op.name, loc)
		}
		v = fmt.Sprintf("v%d", *alloc)
		*alloc++
		fmt.Fprintf(w, "%s := b.NewValue0(%s, Op%s%s, %s)\n", v, pos, oparch, op.name, typ)
		if move && top {
			// Rewrite the original value into a copy of the new one.
			fmt.Fprintf(w, "v.reset(OpCopy)\n")
			fmt.Fprintf(w, "v.AddArg(%s)\n", v)
		}
	}

	if auxint != "" {
		fmt.Fprintf(w, "%s.AuxInt = %s\n", v, auxint)
	}
	if aux != "" {
		fmt.Fprintf(w, "%s.Aux = %s\n", v, aux)
	}
	for _, arg := range args {
		x := genResult0(w, arch, arg, alloc, false, move, loc, pos)
		fmt.Fprintf(w, "%s.AddArg(%s)\n", v, x)
	}

	return v
}
617
// split breaks s into its top-level fields, separated by spaces or tabs, and
// returns them trimmed.
//
// Text enclosed in (), <>, [] or {} is never split, even when it contains
// white space. While inside such a bracket, only that bracket kind is counted
// for nesting. A field made only of bracket characters, such as "()", is not
// emitted, because bracket characters at depth 0 do not mark the field as
// non-empty. split panics if s ends inside an open bracket.
func split(s string) []string {
	var r []string

outer:
	for s != "" {
		d := 0               // nesting depth of the current bracket kind
		var open, close byte // the bracket pair that opened depth 1
		nonsp := false       // the current field has a non-separator character
		for i := 0; i < len(s); i++ {
			switch {
			case d == 0 && s[i] == '(':
				open, close = '(', ')'
				d++
			case d == 0 && s[i] == '<':
				open, close = '<', '>'
				d++
			case d == 0 && s[i] == '[':
				open, close = '[', ']'
				d++
			case d == 0 && s[i] == '{':
				open, close = '{', '}'
				d++
			case d == 0 && (s[i] == ' ' || s[i] == '\t'):
				// Top-level separator: end the current field, if any, and
				// restart the scan on the remaining text.
				if nonsp {
					r = append(r, strings.TrimSpace(s[:i]))
					s = s[i:]
					continue outer
				}
			case d > 0 && s[i] == open:
				d++
			case d > 0 && s[i] == close:
				d--
			default:
				nonsp = true
			}
		}
		if d != 0 {
			panic("imbalanced expression: " + s)
		}
		if nonsp {
			r = append(r, strings.TrimSpace(s))
		}
		break
	}
	return r
}
664
665
666 func isBlock(name string, arch arch) bool {
667 for _, b := range genericBlocks {
668 if b.name == name {
669 return true
670 }
671 }
672 for _, b := range arch.blocks {
673 if b.name == name {
674 return true
675 }
676 }
677 return false
678 }
679
680 func extract(val string) (op string, typ string, auxint string, aux string, args []string) {
681 val = val[1 : len(val)-1]
682
683
684
685 s := split(val)
686
687
688 op = s[0]
689 for _, a := range s[1:] {
690 switch a[0] {
691 case '<':
692 typ = a[1 : len(a)-1]
693 case '[':
694 auxint = a[1 : len(a)-1]
695 case '{':
696 aux = a[1 : len(a)-1]
697 default:
698 args = append(args, a)
699 }
700 }
701 return
702 }
703
704
705
706
707
// parseValue parses a parenthesized value expression from a rule, such as
// "(ADDQconst <t> [c] x)", and resolves its op.
//
// It returns:
//   - op: the op's definition;
//   - oparch: "" for a generic op, or arch.name for an architecture op;
//   - typ, auxint, aux, args: the parts of val as returned by extract.
//
// loc is the rule location, used in diagnostics. The following are fatal:
//   - no op with that name accepts len(args) arguments;
//   - the name matches both a generic op and an arch op;
//   - an auxint or aux is given to an op whose aux kind cannot hold it.
func parseValue(val string, arch arch, loc string) (op opData, oparch string, typ string, auxint string, aux string, args []string) {
	// Resolve the op.
	var s string
	s, typ, auxint, aux, args = extract(val)

	// match reports whether x is the op named s with a compatible argument
	// count (argLength -1 means any count).
	// With strict == true, an arity mismatch is a non-match.
	// With strict == false, which is used only after the lookup has already
	// failed, a name match with the wrong arity is logged to explain the
	// failure.
	match := func(x opData, strict bool, archname string) bool {
		if x.name != s {
			return false
		}
		if x.argLength != -1 && int(x.argLength) != len(args) {
			if strict {
				return false
			} else {
				log.Printf("%s: op %s (%s) should have %d args, has %d", loc, s, archname, x.argLength, len(args))
			}
		}
		return true
	}

	for _, x := range genericOps {
		if match(x, true, "generic") {
			op = x
			break
		}
	}
	if arch.name != "generic" {
		for _, x := range arch.ops {
			if match(x, true, arch.name) {
				if op.name != "" {
					log.Fatalf("%s: matches for op %s found in both generic and %s", loc, op.name, arch.name)
				}
				op = x
				oparch = arch.name
				break
			}
		}
	}

	if op.name == "" {
		// Failed to find the op. Run through everything again with
		// strict == false to log useful diagnostics before failing.
		for _, x := range genericOps {
			match(x, false, "generic")
		}
		for _, x := range arch.ops {
			match(x, false, arch.name)
		}
		log.Fatalf("%s: unknown op %s", loc, s)
	}

	// Sanity check aux and auxint against the op's aux kind.
	if auxint != "" {
		switch op.aux {
		case "Bool", "Int8", "Int16", "Int32", "Int64", "Int128", "Float32", "Float64", "SymOff", "SymValAndOff", "SymInt32", "TypSize":
		default:
			log.Fatalf("%s: op %s %s can't have auxint", loc, op.name, op.aux)
		}
	}
	if aux != "" {
		switch op.aux {
		case "String", "Sym", "SymOff", "SymValAndOff", "SymInt32", "Typ", "TypSize", "CCop":
		default:
			log.Fatalf("%s: op %s %s can't have aux", loc, op.name, op.aux)
		}
	}

	return
}
783
784 func blockName(name string, arch arch) string {
785 for _, b := range genericBlocks {
786 if b.name == name {
787 return "Block" + name
788 }
789 }
790 return "Block" + arch.name + name
791 }
792
793
// typeName returns the Go expression, for use in generated code, for the type
// named typ in an op definition:
//   - a tuple "(T1,T2)" becomes types.NewTuple of its two component types
//     (white space around the components is ignored);
//   - Flags, Mem, Void and Int128 become types.TypeFlags, etc.;
//   - any other name becomes typ.<name>, a field of the typ table that the
//     generated code declares.
// It panics if a tuple does not have exactly two components.
func typeName(typ string) string {
	if typ[0] == '(' {
		ts := strings.Split(typ[1:len(typ)-1], ",")
		if len(ts) != 2 {
			panic("Tuple expect 2 arguments")
		}
		first := strings.TrimSpace(ts[0])
		second := strings.TrimSpace(ts[1])
		return "types.NewTuple(" + typeName(first) + ", " + typeName(second) + ")"
	}
	switch typ {
	case "Flags", "Mem", "Void", "Int128":
		return "types.Type" + typ
	default:
		return "typ." + typ
	}
}
809
810
// unbalanced reports whether s contains a different number of '(' and ')'
// characters. Only the counts are compared, not their order.
func unbalanced(s string) bool {
	depth := 0
	for _, c := range s {
		switch c {
		case '(':
			depth++
		case ')':
			depth--
		}
	}
	return depth != 0
}
823
824
// isVariableRegexp matches a plain identifier: a letter or underscore,
// followed by letters, digits or underscores. It is compiled once because
// isVariable is called for nearly every token of every rule.
var isVariableRegexp = regexp.MustCompile("^[A-Za-z_][A-Za-z_0-9]*$")

// isVariable reports whether s is a plain identifier, i.e. something a rule
// can bind as a pattern variable, as opposed to an expression or Go code.
func isVariable(s string) bool {
	return isVariableRegexp.MatchString(s)
}
832
833
// opRegexp matches an alternation such as "(ADDQ|ADDL)": a parenthesized list
// of two or more word-character names separated by '|'. expandOr expands
// these into one rule per alternative.
var opRegexp = regexp.MustCompile(`[(](\w+[|])+\w+[)]`)
835
836
837
838
839
// excludeFromExpansion reports whether the alternation at s[idx[0]:idx[1]]
// should be left alone by expandOr. It is excluded when either holds:
//   - it lies inside an unclosed '[' (an auxint expression);
//   - it lies between "&&" and "->", i.e. inside the rule's condition.
func excludeFromExpansion(s string, idx []int) bool {
	before, after := s[:idx[0]], s[idx[1]:]
	insideAuxInt := strings.LastIndexByte(before, '[') > strings.LastIndexByte(before, ']')
	insideCond := strings.Contains(before, "&&") && strings.Contains(after, "->")
	return insideAuxInt || insideCond
}
853
854
855 func expandOr(r string) []string {
856
857
858
859
860
861 n := 1
862 for _, idx := range opRegexp.FindAllStringIndex(r, -1) {
863 if excludeFromExpansion(r, idx) {
864 continue
865 }
866 s := r[idx[0]:idx[1]]
867 c := strings.Count(s, "|") + 1
868 if c == 1 {
869 continue
870 }
871 if n > 1 && n != c {
872 log.Fatalf("'|' count doesn't match in %s: both %d and %d\n", r, n, c)
873 }
874 n = c
875 }
876 if n == 1 {
877
878 return []string{r}
879 }
880
881 res := make([]string, n)
882 for i := 0; i < n; i++ {
883 buf := new(strings.Builder)
884 x := 0
885 for _, idx := range opRegexp.FindAllStringIndex(r, -1) {
886 if excludeFromExpansion(r, idx) {
887 continue
888 }
889 buf.WriteString(r[x:idx[0]])
890 s := r[idx[0]+1 : idx[1]-1]
891 buf.WriteString(strings.Split(s, "|")[i])
892 x = idx[1]
893 }
894 buf.WriteString(r[x:])
895 res[i] = buf.String()
896 }
897 return res
898 }
899
900
901
902
903 func commute(r string, arch arch) []string {
904 match, cond, result := Rule{rule: r}.parse()
905 a := commute1(match, varCount(match), arch)
906 for i, m := range a {
907 if cond != "" {
908 m += " && " + cond
909 }
910 m += " -> " + result
911 a[i] = m
912 }
913 if len(a) == 1 && normalizeWhitespace(r) != normalizeWhitespace(a[0]) {
914 fmt.Println(normalizeWhitespace(r))
915 fmt.Println(normalizeWhitespace(a[0]))
916 panic("commute() is not the identity for noncommuting rule")
917 }
918 if false && len(a) > 1 {
919 fmt.Println(r)
920 for _, x := range a {
921 fmt.Println(" " + x)
922 }
923 }
924 return a
925 }
926
// commute1 returns every variant of the match (sub-)expression m obtained by
// swapping the first two arguments of commutative ops, recursively.
//
// Type, auxint and aux parts ("<...>", "[...]", "{...}") and bare variables
// are returned unchanged. A "name:" prefix on a compound expression is kept
// on every variant. The unswapped variants come first, followed by the
// swapped ones.
//
// An op counts as commutative if the generic op or the arch op with its name
// is marked commutative.
//
// cnt holds the number of occurrences of each variable in the whole match
// (see varCount). No swapped variant is generated when the first two
// arguments of a commutative op are either:
//   - two distinct variables that each occur once, or
//   - the same variable occurring exactly twice.
// In both cases swapping cannot change what the rule matches.
func commute1(m string, cnt map[string]int, arch arch) []string {
	if m[0] == '<' || m[0] == '[' || m[0] == '{' || isVariable(m) {
		return []string{m}
	}
	// Split off an optional "name:" prefix, then the parenthesized body.
	var prefix string
	colon := strings.Index(m, ":")
	if colon >= 0 && isVariable(m[:colon]) {
		prefix = m[:colon+1]
		m = m[colon+1:]
	}
	if m[0] != '(' || m[len(m)-1] != ')' {
		panic("non-compound expr in commute1: " + m)
	}
	s := split(m[1 : len(m)-1])
	op := s[0]

	// Figure out whether the op is commutative.
	commutative := false
	for _, x := range genericOps {
		if op == x.name {
			if x.commutative {
				commutative = true
			}
			break
		}
	}
	if arch.name != "generic" {
		for _, x := range arch.ops {
			if op == x.name {
				if x.commutative {
					commutative = true
				}
				break
			}
		}
	}
	var idx0, idx1 int
	if commutative {
		// Find the indexes of the first two real arguments (skipping the op
		// name and any type/auxint/aux parts).
		for i, arg := range s {
			if i == 0 || arg[0] == '<' || arg[0] == '[' || arg[0] == '{' {
				continue
			}
			if idx0 == 0 {
				idx0 = i
				continue
			}
			if idx1 == 0 {
				idx1 = i
				break
			}
		}
		if idx1 == 0 {
			panic("couldn't find first two args of commutative op " + s[0])
		}
		if cnt[s[idx0]] == 1 && cnt[s[idx1]] == 1 || s[idx0] == s[idx1] && cnt[s[idx0]] == 2 {
			// E.g. (Add x y) with no other uses of x and y: the swapped
			// (Add y x) would match exactly the same values, so skip it.
			commutative = false
		}
	}

	// Recursively commute every part.
	a := make([][]string, len(s))
	for i, arg := range s {
		a[i] = commute1(arg, cnt, arch)
	}

	// Choose all combinations of the parts' variants.
	r := crossProduct(a)

	// If commutative, do that again with the two arguments reversed.
	if commutative {
		a[idx0], a[idx1] = a[idx1], a[idx0]
		r = append(r, crossProduct(a)...)
	}

	// Re-wrap each variant in parentheses and restore the name prefix.
	for i, x := range r {
		r[i] = prefix + "(" + x + ")"
	}
	return r
}
1011
1012
1013
1014 func varCount(m string) map[string]int {
1015 cnt := map[string]int{}
1016 varCount1(m, cnt)
1017 return cnt
1018 }
1019 func varCount1(m string, cnt map[string]int) {
1020 if m[0] == '<' || m[0] == '[' || m[0] == '{' {
1021 return
1022 }
1023 if isVariable(m) {
1024 cnt[m]++
1025 return
1026 }
1027
1028 colon := strings.Index(m, ":")
1029 if colon >= 0 && isVariable(m[:colon]) {
1030 cnt[m[:colon]]++
1031 m = m[colon+1:]
1032 }
1033 if m[0] != '(' || m[len(m)-1] != ')' {
1034 panic("non-compound expr in commute1: " + m)
1035 }
1036 s := split(m[1 : len(m)-1])
1037 for _, arg := range s[1:] {
1038 varCount1(arg, cnt)
1039 }
1040 }
1041
1042
1043
1044
// crossProduct returns every space-joined combination that takes one string
// from each element of x, in order. The choice from x[0] varies fastest. With
// a single element, x[0] itself is returned. x must not be empty.
func crossProduct(x [][]string) []string {
	last := len(x) - 1
	acc := x[last]
	// Prepend the choices from each earlier position, working backwards.
	for k := last - 1; k >= 0; k-- {
		var next []string
		for _, tail := range acc {
			for _, head := range x[k] {
				next = append(next, head+" "+tail)
			}
		}
		acc = next
	}
	return acc
}
1057
1058
// normalizeWhitespace canonicalizes the spacing of a rule so that two rules
// differing only in layout compare equal. It collapses white space runs to
// single spaces, removes spaces just inside parentheses, and ensures a space
// before "->". The replacements are applied one after another, in order.
func normalizeWhitespace(x string) string {
	out := strings.Join(strings.Fields(x), " ")
	for _, pair := range [...][2]string{
		{"( ", "("},
		{" )", ")"},
		{")->", ") ->"},
	} {
		out = strings.Replace(out, pair[0], pair[1], -1)
	}
	return out
}
1066
View as plain text