Source file src/net/dnsclient_unix.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package net
16
17 import (
18 "context"
19 "errors"
20 "io"
21 "math/rand"
22 "os"
23 "sync"
24 "time"
25
26 "golang.org/x/net/dns/dnsmessage"
27 )
28
// Named values for the useTCP argument of (*Resolver).exchange.
const (
	// useTCPOnly makes exchange query over TCP only.
	useTCPOnly = true

	// useUDPOrTCP makes exchange query over UDP first and retry over TCP
	// when the UDP response is truncated.
	useUDPOrTCP = false
)
34
// Internal errors produced while building, exchanging and validating DNS
// messages. They are compared by identity, and their text becomes the Err
// field of the DNSError values reported to callers.
var (
	// errLameReferral: a successful reply with no answers from a server
	// that is neither authoritative nor recursive.
	errLameReferral = errors.New("lame referral")

	// errCannotUnmarshalDNSMessage: a response could not be parsed.
	errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message")

	// errCannotMarshalDNSMessage: a query could not be built.
	errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message")

	// errServerMisbehaving: an unexpected non-success response code.
	errServerMisbehaving = errors.New("server misbehaving")

	// errInvalidDNSResponse: a response that does not match the query or
	// has a malformed question section.
	errInvalidDNSResponse = errors.New("invalid DNS response")

	// errNoAnswerFromDNSServer: every transport attempt ended without a
	// usable (non-truncated) answer.
	errNoAnswerFromDNSServer = errors.New("no answer from DNS server")

	// errServerTemporarlyMisbehaving has the same text as
	// errServerMisbehaving but is a distinct value, returned for SERVFAIL.
	// When it is turned into a DNSError, IsTemporary is set to true.
	errServerTemporarlyMisbehaving = errors.New("server misbehaving")
)
48
49 func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) {
50 id = uint16(rand.Int()) ^ uint16(time.Now().UnixNano())
51 b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true})
52 b.EnableCompression()
53 if err := b.StartQuestions(); err != nil {
54 return 0, nil, nil, err
55 }
56 if err := b.Question(q); err != nil {
57 return 0, nil, nil, err
58 }
59 tcpReq, err = b.Finish()
60 udpReq = tcpReq[2:]
61 l := len(tcpReq) - 2
62 tcpReq[0] = byte(l >> 8)
63 tcpReq[1] = byte(l)
64 return id, udpReq, tcpReq, err
65 }
66
67 func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool {
68 if !respHdr.Response {
69 return false
70 }
71 if reqID != respHdr.ID {
72 return false
73 }
74 if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) {
75 return false
76 }
77 return true
78 }
79
80 func dnsPacketRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
81 if _, err := c.Write(b); err != nil {
82 return dnsmessage.Parser{}, dnsmessage.Header{}, err
83 }
84
85 b = make([]byte, 512)
86 for {
87 n, err := c.Read(b)
88 if err != nil {
89 return dnsmessage.Parser{}, dnsmessage.Header{}, err
90 }
91 var p dnsmessage.Parser
92
93
94
95 h, err := p.Start(b[:n])
96 if err != nil {
97 continue
98 }
99 q, err := p.Question()
100 if err != nil || !checkResponse(id, query, h, q) {
101 continue
102 }
103 return p, h, nil
104 }
105 }
106
107 func dnsStreamRoundTrip(c Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) {
108 if _, err := c.Write(b); err != nil {
109 return dnsmessage.Parser{}, dnsmessage.Header{}, err
110 }
111
112 b = make([]byte, 1280)
113 if _, err := io.ReadFull(c, b[:2]); err != nil {
114 return dnsmessage.Parser{}, dnsmessage.Header{}, err
115 }
116 l := int(b[0])<<8 | int(b[1])
117 if l > len(b) {
118 b = make([]byte, l)
119 }
120 n, err := io.ReadFull(c, b[:l])
121 if err != nil {
122 return dnsmessage.Parser{}, dnsmessage.Header{}, err
123 }
124 var p dnsmessage.Parser
125 h, err := p.Start(b[:n])
126 if err != nil {
127 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
128 }
129 q, err := p.Question()
130 if err != nil {
131 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage
132 }
133 if !checkResponse(id, query, h, q) {
134 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
135 }
136 return p, h, nil
137 }
138
139
140 func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Question, timeout time.Duration, useTCP bool) (dnsmessage.Parser, dnsmessage.Header, error) {
141 q.Class = dnsmessage.ClassINET
142 id, udpReq, tcpReq, err := newRequest(q)
143 if err != nil {
144 return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage
145 }
146 var networks []string
147 if useTCP {
148 networks = []string{"tcp"}
149 } else {
150 networks = []string{"udp", "tcp"}
151 }
152 for _, network := range networks {
153 ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout))
154 defer cancel()
155
156 c, err := r.dial(ctx, network, server)
157 if err != nil {
158 return dnsmessage.Parser{}, dnsmessage.Header{}, err
159 }
160 if d, ok := ctx.Deadline(); ok && !d.IsZero() {
161 c.SetDeadline(d)
162 }
163 var p dnsmessage.Parser
164 var h dnsmessage.Header
165 if _, ok := c.(PacketConn); ok {
166 p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
167 } else {
168 p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
169 }
170 c.Close()
171 if err != nil {
172 return dnsmessage.Parser{}, dnsmessage.Header{}, mapErr(err)
173 }
174 if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone {
175 return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse
176 }
177 if h.Truncated {
178 continue
179 }
180 return p, h, nil
181 }
182 return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer
183 }
184
185
186 func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error {
187 if h.RCode == dnsmessage.RCodeNameError {
188 return errNoSuchHost
189 }
190
191 _, err := p.AnswerHeader()
192 if err != nil && err != dnsmessage.ErrSectionDone {
193 return errCannotUnmarshalDNSMessage
194 }
195
196
197
198 if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone {
199 return errLameReferral
200 }
201
202 if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError {
203
204
205
206
207
208 if h.RCode == dnsmessage.RCodeServerFailure {
209 return errServerTemporarlyMisbehaving
210 }
211 return errServerMisbehaving
212 }
213
214 return nil
215 }
216
217 func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error {
218 for {
219 h, err := p.AnswerHeader()
220 if err == dnsmessage.ErrSectionDone {
221 return errNoSuchHost
222 }
223 if err != nil {
224 return errCannotUnmarshalDNSMessage
225 }
226 if h.Type == qtype {
227 return nil
228 }
229 if err := p.SkipAnswer(); err != nil {
230 return errCannotUnmarshalDNSMessage
231 }
232 }
233 }
234
235
236
237 func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
238 var lastErr error
239 serverOffset := cfg.serverOffset()
240 sLen := uint32(len(cfg.servers))
241
242 n, err := dnsmessage.NewName(name)
243 if err != nil {
244 return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
245 }
246 q := dnsmessage.Question{
247 Name: n,
248 Type: qtype,
249 Class: dnsmessage.ClassINET,
250 }
251
252 for i := 0; i < cfg.attempts; i++ {
253 for j := uint32(0); j < sLen; j++ {
254 server := cfg.servers[(serverOffset+j)%sLen]
255
256 p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP)
257 if err != nil {
258 dnsErr := &DNSError{
259 Err: err.Error(),
260 Name: name,
261 Server: server,
262 }
263 if nerr, ok := err.(Error); ok && nerr.Timeout() {
264 dnsErr.IsTimeout = true
265 }
266
267
268 if _, ok := err.(*OpError); ok {
269 dnsErr.IsTemporary = true
270 }
271 lastErr = dnsErr
272 continue
273 }
274
275 if err := checkHeader(&p, h); err != nil {
276 dnsErr := &DNSError{
277 Err: err.Error(),
278 Name: name,
279 Server: server,
280 }
281 if err == errServerTemporarlyMisbehaving {
282 dnsErr.IsTemporary = true
283 }
284 if err == errNoSuchHost {
285
286
287
288 dnsErr.IsNotFound = true
289 return p, server, dnsErr
290 }
291 lastErr = dnsErr
292 continue
293 }
294
295 err = skipToAnswer(&p, qtype)
296 if err == nil {
297 return p, server, nil
298 }
299 lastErr = &DNSError{
300 Err: err.Error(),
301 Name: name,
302 Server: server,
303 }
304 if err == errNoSuchHost {
305
306
307
308 lastErr.(*DNSError).IsNotFound = true
309 return p, server, lastErr
310 }
311 }
312 }
313 return dnsmessage.Parser{}, "", lastErr
314 }
315
316
317 type resolverConfig struct {
318 initOnce sync.Once
319
320
321
322 ch chan struct{}
323 lastChecked time.Time
324
325 mu sync.RWMutex
326 dnsConfig *dnsConfig
327 }
328
329 var resolvConf resolverConfig
330
331
332 func (conf *resolverConfig) init() {
333
334
335 conf.dnsConfig = systemConf().resolv
336 if conf.dnsConfig == nil {
337 conf.dnsConfig = dnsReadConfig("/etc/resolv.conf")
338 }
339 conf.lastChecked = time.Now()
340
341
342
343 conf.ch = make(chan struct{}, 1)
344 }
345
346
347
348
349 func (conf *resolverConfig) tryUpdate(name string) {
350 conf.initOnce.Do(conf.init)
351
352
353 if !conf.tryAcquireSema() {
354 return
355 }
356 defer conf.releaseSema()
357
358 now := time.Now()
359 if conf.lastChecked.After(now.Add(-5 * time.Second)) {
360 return
361 }
362 conf.lastChecked = now
363
364 var mtime time.Time
365 if fi, err := os.Stat(name); err == nil {
366 mtime = fi.ModTime()
367 }
368 if mtime.Equal(conf.dnsConfig.mtime) {
369 return
370 }
371
372 dnsConf := dnsReadConfig(name)
373 conf.mu.Lock()
374 conf.dnsConfig = dnsConf
375 conf.mu.Unlock()
376 }
377
378 func (conf *resolverConfig) tryAcquireSema() bool {
379 select {
380 case conf.ch <- struct{}{}:
381 return true
382 default:
383 return false
384 }
385 }
386
387 func (conf *resolverConfig) releaseSema() {
388 <-conf.ch
389 }
390
391 func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) {
392 if !isDomainName(name) {
393
394
395
396
397
398 return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
399 }
400 resolvConf.tryUpdate("/etc/resolv.conf")
401 resolvConf.mu.RLock()
402 conf := resolvConf.dnsConfig
403 resolvConf.mu.RUnlock()
404 var (
405 p dnsmessage.Parser
406 server string
407 err error
408 )
409 for _, fqdn := range conf.nameList(name) {
410 p, server, err = r.tryOneName(ctx, conf, fqdn, qtype)
411 if err == nil {
412 break
413 }
414 if nerr, ok := err.(Error); ok && nerr.Temporary() && r.strictErrors() {
415
416
417 break
418 }
419 }
420 if err == nil {
421 return p, server, nil
422 }
423 if err, ok := err.(*DNSError); ok {
424
425
426
427 err.Name = name
428 }
429 return dnsmessage.Parser{}, "", err
430 }
431
432
433
434
435
436 func avoidDNS(name string) bool {
437 if name == "" {
438 return true
439 }
440 if name[len(name)-1] == '.' {
441 name = name[:len(name)-1]
442 }
443 return stringsHasSuffixFold(name, ".onion")
444 }
445
446
447 func (conf *dnsConfig) nameList(name string) []string {
448 if avoidDNS(name) {
449 return nil
450 }
451
452
453 l := len(name)
454 rooted := l > 0 && name[l-1] == '.'
455 if l > 254 || l == 254 && rooted {
456 return nil
457 }
458
459
460 if rooted {
461 return []string{name}
462 }
463
464 hasNdots := count(name, '.') >= conf.ndots
465 name += "."
466 l++
467
468
469 names := make([]string, 0, 1+len(conf.search))
470
471 if hasNdots {
472 names = append(names, name)
473 }
474
475 for _, suffix := range conf.search {
476 if l+len(suffix) <= 254 {
477 names = append(names, name+suffix)
478 }
479 }
480
481 if !hasNdots {
482 names = append(names, name)
483 }
484 return names
485 }
486
487
488
489
490 type hostLookupOrder int
491
492 const (
493
494 hostLookupCgo hostLookupOrder = iota
495 hostLookupFilesDNS
496 hostLookupDNSFiles
497 hostLookupFiles
498 hostLookupDNS
499 )
500
501 var lookupOrderName = map[hostLookupOrder]string{
502 hostLookupCgo: "cgo",
503 hostLookupFilesDNS: "files,dns",
504 hostLookupDNSFiles: "dns,files",
505 hostLookupFiles: "files",
506 hostLookupDNS: "dns",
507 }
508
509 func (o hostLookupOrder) String() string {
510 if s, ok := lookupOrderName[o]; ok {
511 return s
512 }
513 return "hostLookupOrder=" + itoa(int(o)) + "??"
514 }
515
516
517
518
519
520
521
522 func (r *Resolver) goLookupHost(ctx context.Context, name string) (addrs []string, err error) {
523 return r.goLookupHostOrder(ctx, name, hostLookupFilesDNS)
524 }
525
526 func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []string, err error) {
527 if order == hostLookupFilesDNS || order == hostLookupFiles {
528
529 addrs = lookupStaticHost(name)
530 if len(addrs) > 0 || order == hostLookupFiles {
531 return
532 }
533 }
534 ips, _, err := r.goLookupIPCNAMEOrder(ctx, name, order)
535 if err != nil {
536 return
537 }
538 addrs = make([]string, 0, len(ips))
539 for _, ip := range ips {
540 addrs = append(addrs, ip.String())
541 }
542 return
543 }
544
545
546 func goLookupIPFiles(name string) (addrs []IPAddr) {
547 for _, haddr := range lookupStaticHost(name) {
548 haddr, zone := splitHostZone(haddr)
549 if ip := ParseIP(haddr); ip != nil {
550 addr := IPAddr{IP: ip, Zone: zone}
551 addrs = append(addrs, addr)
552 }
553 }
554 sortByRFC6724(addrs)
555 return
556 }
557
558
559
560 func (r *Resolver) goLookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
561 order := systemConf().hostLookupOrder(r, host)
562 addrs, _, err = r.goLookupIPCNAMEOrder(ctx, host, order)
563 return
564 }
565
566 func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, name string, order hostLookupOrder) (addrs []IPAddr, cname dnsmessage.Name, err error) {
567 if order == hostLookupFilesDNS || order == hostLookupFiles {
568 addrs = goLookupIPFiles(name)
569 if len(addrs) > 0 || order == hostLookupFiles {
570 return addrs, dnsmessage.Name{}, nil
571 }
572 }
573 if !isDomainName(name) {
574
575 return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
576 }
577 resolvConf.tryUpdate("/etc/resolv.conf")
578 resolvConf.mu.RLock()
579 conf := resolvConf.dnsConfig
580 resolvConf.mu.RUnlock()
581 type result struct {
582 p dnsmessage.Parser
583 server string
584 error
585 }
586 lane := make(chan result, 1)
587 qtypes := [...]dnsmessage.Type{dnsmessage.TypeA, dnsmessage.TypeAAAA}
588 var queryFn func(fqdn string, qtype dnsmessage.Type)
589 var responseFn func(fqdn string, qtype dnsmessage.Type) result
590 if conf.singleRequest {
591 queryFn = func(fqdn string, qtype dnsmessage.Type) {}
592 responseFn = func(fqdn string, qtype dnsmessage.Type) result {
593 dnsWaitGroup.Add(1)
594 defer dnsWaitGroup.Done()
595 p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
596 return result{p, server, err}
597 }
598 } else {
599 queryFn = func(fqdn string, qtype dnsmessage.Type) {
600 dnsWaitGroup.Add(1)
601 go func(qtype dnsmessage.Type) {
602 p, server, err := r.tryOneName(ctx, conf, fqdn, qtype)
603 lane <- result{p, server, err}
604 dnsWaitGroup.Done()
605 }(qtype)
606 }
607 responseFn = func(fqdn string, qtype dnsmessage.Type) result {
608 return <-lane
609 }
610 }
611 var lastErr error
612 for _, fqdn := range conf.nameList(name) {
613 for _, qtype := range qtypes {
614 queryFn(fqdn, qtype)
615 }
616 hitStrictError := false
617 for _, qtype := range qtypes {
618 result := responseFn(fqdn, qtype)
619 if result.error != nil {
620 if nerr, ok := result.error.(Error); ok && nerr.Temporary() && r.strictErrors() {
621
622 hitStrictError = true
623 lastErr = result.error
624 } else if lastErr == nil || fqdn == name+"." {
625
626 lastErr = result.error
627 }
628 continue
629 }
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646 loop:
647 for {
648 h, err := result.p.AnswerHeader()
649 if err != nil && err != dnsmessage.ErrSectionDone {
650 lastErr = &DNSError{
651 Err: "cannot marshal DNS message",
652 Name: name,
653 Server: result.server,
654 }
655 }
656 if err != nil {
657 break
658 }
659 switch h.Type {
660 case dnsmessage.TypeA:
661 a, err := result.p.AResource()
662 if err != nil {
663 lastErr = &DNSError{
664 Err: "cannot marshal DNS message",
665 Name: name,
666 Server: result.server,
667 }
668 break loop
669 }
670 addrs = append(addrs, IPAddr{IP: IP(a.A[:])})
671
672 case dnsmessage.TypeAAAA:
673 aaaa, err := result.p.AAAAResource()
674 if err != nil {
675 lastErr = &DNSError{
676 Err: "cannot marshal DNS message",
677 Name: name,
678 Server: result.server,
679 }
680 break loop
681 }
682 addrs = append(addrs, IPAddr{IP: IP(aaaa.AAAA[:])})
683
684 default:
685 if err := result.p.SkipAnswer(); err != nil {
686 lastErr = &DNSError{
687 Err: "cannot marshal DNS message",
688 Name: name,
689 Server: result.server,
690 }
691 break loop
692 }
693 continue
694 }
695 if cname.Length == 0 && h.Name.Length != 0 {
696 cname = h.Name
697 }
698 }
699 }
700 if hitStrictError {
701
702
703
704 addrs = nil
705 break
706 }
707 if len(addrs) > 0 {
708 break
709 }
710 }
711 if lastErr, ok := lastErr.(*DNSError); ok {
712
713
714
715 lastErr.Name = name
716 }
717 sortByRFC6724(addrs)
718 if len(addrs) == 0 {
719 if order == hostLookupDNSFiles {
720 addrs = goLookupIPFiles(name)
721 }
722 if len(addrs) == 0 && lastErr != nil {
723 return nil, dnsmessage.Name{}, lastErr
724 }
725 }
726 return addrs, cname, nil
727 }
728
729
730 func (r *Resolver) goLookupCNAME(ctx context.Context, host string) (string, error) {
731 order := systemConf().hostLookupOrder(r, host)
732 _, cname, err := r.goLookupIPCNAMEOrder(ctx, host, order)
733 return cname.String(), err
734 }
735
736
737
738
739
740
741 func (r *Resolver) goLookupPTR(ctx context.Context, addr string) ([]string, error) {
742 names := lookupStaticAddr(addr)
743 if len(names) > 0 {
744 return names, nil
745 }
746 arpa, err := reverseaddr(addr)
747 if err != nil {
748 return nil, err
749 }
750 p, server, err := r.lookup(ctx, arpa, dnsmessage.TypePTR)
751 if err != nil {
752 return nil, err
753 }
754 var ptrs []string
755 for {
756 h, err := p.AnswerHeader()
757 if err == dnsmessage.ErrSectionDone {
758 break
759 }
760 if err != nil {
761 return nil, &DNSError{
762 Err: "cannot marshal DNS message",
763 Name: addr,
764 Server: server,
765 }
766 }
767 if h.Type != dnsmessage.TypePTR {
768 err := p.SkipAnswer()
769 if err != nil {
770 return nil, &DNSError{
771 Err: "cannot marshal DNS message",
772 Name: addr,
773 Server: server,
774 }
775 }
776 continue
777 }
778 ptr, err := p.PTRResource()
779 if err != nil {
780 return nil, &DNSError{
781 Err: "cannot marshal DNS message",
782 Name: addr,
783 Server: server,
784 }
785 }
786 ptrs = append(ptrs, ptr.PTR.String())
787
788 }
789 return ptrs, nil
790 }
791
View as plain text