// Package urlguard validates URLs before fetching to prevent SSRF (Server-Side
// Request Forgery). It rejects non-http(s) schemes and any host that resolves
// to a private, loopback, link-local, multicast, or unspecified IP address.
//
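// Call Validate before each request to reject unsafe URLs early. Because the
// eventual fetch resolves the host name again, Validate alone cannot stop DNS
// rebinding; SafeDialer returns a *net.Dialer that re-validates the resolved
// IP at socket-connect time — the final defense against DNS rebinding attacks.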
package urlguard

import (
	"context"
	"fmt"
	"net"
	"net/url"
	"strings"
	"syscall"
	"time"
)

// GuardError is returned when a URL or dial target fails validation.
// Code is a short machine-readable category (this package uses
// "invalid_url", "invalid_scheme", and "private_ip"); Message carries the
// human-readable detail.
type GuardError struct {
	Code    string // machine-readable rejection category
	Message string // human-readable explanation
}

// Compile-time check that *GuardError satisfies the error interface.
var _ error = (*GuardError)(nil)

// Error implements the error interface, rendering "Code: Message".
func (ge *GuardError) Error() string {
	// All operands are strings, so fmt.Sprint inserts no extra spaces.
	return fmt.Sprint(ge.Code, ": ", ge.Message)
}

// Validate checks whether rawURL is safe to fetch on behalf of an untrusted
// source. It returns nil when the URL is acceptable, otherwise a *GuardError
// whose Code is:
//
//   - "invalid_url" for an empty or unparsable URL, a missing host, or a host
//     whose DNS lookup fails (including timing out) or yields no addresses;
//   - "invalid_scheme" for any scheme other than http or https;
//   - "private_ip" for localhost names, or when ANY resolved address is
//     rejected by IsPrivateIP.
//
// DNS is resolved on every call (this function keeps no cache) and the lookup
// is time-bounded. Validate alone cannot prevent DNS rebinding, because the
// actual fetch resolves the name again; use SafeDialer for the connection.
func Validate(rawURL string) error {
	// context.Background() alone carries no deadline; bound the lookup so a
	// slow or hostile name server cannot stall the caller indefinitely.
	const lookupTimeout = 10 * time.Second

	if rawURL == "" {
		return &GuardError{Code: "invalid_url", Message: "empty URL"}
	}
	u, err := url.Parse(rawURL)
	if err != nil {
		return &GuardError{Code: "invalid_url", Message: err.Error()}
	}
	if u.Scheme != "http" && u.Scheme != "https" {
		return &GuardError{Code: "invalid_scheme", Message: "only http(s) allowed, got " + u.Scheme}
	}
	host := u.Hostname()
	if host == "" {
		return &GuardError{Code: "invalid_url", Message: "missing host"}
	}
	// Reject localhost by name before touching DNS. A trailing dot marks a
	// fully-qualified name ("localhost." is the same host), so drop it first.
	hostLower := strings.TrimSuffix(strings.ToLower(host), ".")
	if hostLower == "localhost" || strings.HasSuffix(hostLower, ".localhost") {
		return &GuardError{Code: "private_ip", Message: "localhost is not allowed"}
	}

	ctx, cancel := context.WithTimeout(context.Background(), lookupTimeout)
	defer cancel()
	ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
	if err != nil {
		return &GuardError{Code: "invalid_url", Message: "DNS resolution failed: " + err.Error()}
	}
	if len(ips) == 0 {
		return &GuardError{Code: "invalid_url", Message: "no IPs resolved"}
	}
	// Reject if ANY address is private: the eventual dial may use any of them.
	for _, ip := range ips {
		if IsPrivateIP(ip.IP) {
			return &GuardError{
				Code:    "private_ip",
				Message: fmt.Sprintf("host %s resolves to private/reserved IP %s", host, ip.IP),
			}
		}
	}
	return nil
}

// IsPrivateIP reports whether ip must not be contacted on behalf of an
// untrusted caller. It returns true for:
//
//   - nil or malformed IPs (neither 4 nor 16 bytes) — fail closed;
//   - loopback, link-local unicast/multicast, interface-local multicast,
//     multicast, and unspecified addresses;
//   - RFC 1918 private IPv4 and IPv6 unique-local fc00::/7 (net.IP.IsPrivate);
//   - 0.0.0.0/8, 100.64.0.0/10 (carrier-grade NAT; some cloud metadata
//     services live here, e.g. 100.100.100.200), 192.0.0.0/24,
//     198.18.0.0/15, and 240.0.0.0/4 including 255.255.255.255;
//   - deprecated IPv6 site-local fec0::/10;
//   - NAT64 (64:ff9b::/96) and 6to4 (2002::/16) IPv6 addresses whose
//     embedded IPv4 address is itself blocked.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are judged by their IPv4 form.
func IsPrivateIP(ip net.IP) bool {
	if len(ip) != net.IPv4len && len(ip) != net.IPv6len {
		return true // nil, empty, or malformed: fail closed
	}
	if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() ||
		ip.IsInterfaceLocalMulticast() || ip.IsMulticast() || ip.IsUnspecified() ||
		ip.IsPrivate() {
		return true
	}
	if v4 := ip.To4(); v4 != nil {
		return isReservedIPv4(v4)
	}
	// From here ip is a 16-byte, non-IPv4-mapped IPv6 address.
	switch {
	case ip[0] == 0xfe && ip[1]&0xc0 == 0xc0: // fec0::/10 site-local
		return true
	case string(ip[:12]) == nat64WellKnownPrefix: // 64:ff9b::/96
		return IsPrivateIP(ip[12:16])
	case ip[0] == 0x20 && ip[1] == 0x02: // 2002::/16 6to4, IPv4 in bytes 2..5
		return IsPrivateIP(ip[2:6])
	}
	return false
}

// nat64WellKnownPrefix is the first 12 bytes of the RFC 6052 NAT64 prefix
// 64:ff9b::/96; the remaining 4 bytes hold an IPv4 address.
const nat64WellKnownPrefix = "\x00\x64\xff\x9b\x00\x00\x00\x00\x00\x00\x00\x00"

// isReservedIPv4 reports whether the 4-byte address v4 lies in a non-public
// IPv4 range not already covered by the net.IP predicates in IsPrivateIP.
func isReservedIPv4(v4 net.IP) bool {
	switch {
	case v4[0] == 0: // 0.0.0.0/8 "this network"
		return true
	case v4[0] == 100 && v4[1]&0xc0 == 64: // 100.64.0.0/10 carrier-grade NAT
		return true
	case v4[0] == 192 && v4[1] == 0 && v4[2] == 0: // 192.0.0.0/24 IETF assignments
		return true
	case v4[0] == 198 && v4[1]&0xfe == 18: // 198.18.0.0/15 benchmarking
		return true
	case v4[0] >= 240: // 240.0.0.0/4 reserved, incl. 255.255.255.255
		return true
	}
	return false
}

// SafeDialer builds a *net.Dialer (15s connect timeout, 30s keep-alive) whose
// Control hook refuses to open a socket to any address IsPrivateIP rejects.
// Go invokes Control with the concrete IP:port actually being dialed, so this
// check applies to the post-resolution address. Hand it to chromedp or
// http.Transport as the final SSRF defense at socket time.
func SafeDialer() *net.Dialer {
	guard := func(_, address string, _ syscall.RawConn) error {
		hostPart, _, splitErr := net.SplitHostPort(address)
		if splitErr != nil {
			return &GuardError{Code: "invalid_url", Message: splitErr.Error()}
		}
		switch parsed := net.ParseIP(hostPart); {
		case parsed == nil:
			// Anything other than a literal IP here is unexpected; fail closed.
			return &GuardError{Code: "invalid_url", Message: "non-IP in dial address: " + hostPart}
		case IsPrivateIP(parsed):
			return &GuardError{Code: "private_ip", Message: "dial to private IP blocked: " + parsed.String()}
		default:
			return nil
		}
	}

	return &net.Dialer{
		Timeout:   15 * time.Second,
		KeepAlive: 30 * time.Second,
		Control:   guard,
	}
}
