diff --git a/.gitignore b/.gitignore index 113b199..1eb638d 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ profile.cov # OS files .DS_Store +forge-supervisor-bin diff --git a/forge-supervisor/Dockerfile b/forge-supervisor/Dockerfile new file mode 100644 index 0000000..10faf69 --- /dev/null +++ b/forge-supervisor/Dockerfile @@ -0,0 +1,46 @@ +# Build stage +FROM golang:1.21-alpine AS builder + +# Install certificates for TLS and useradd +RUN apk add --no-cache ca-certificates + +WORKDIR /build + +# Copy go mod files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy source code +COPY . . + +# Create agent user (UID 1000) in builder — we'll copy its entry to scratch +RUN adduser -D -u 1000 -G 1000 agent + +# Build static binary with netgo (no cgo) +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ + -ldflags="-s -w" \ + -installsuffix netgo \ + -tags netgo \ + -o /usr/local/bin/forge-supervisor . + +# Final stage — scratch image (no shell, minimal) +FROM scratch + +# Copy certificates +COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/ + +# Copy supervisor binary +COPY --from=builder /usr/local/bin/forge-supervisor /usr/local/bin/forge-supervisor + +# Copy passwd/group for UID 1000 (so 'id' and 'groups' work for the agent) +COPY --from=builder /etc/passwd /etc/passwd +COPY --from=builder /etc/group /etc/group + +# Create /etc/forge/ directory for policy file (mounted at runtime) +COPY --from=builder /etc/group /etc/group +RUN mkdir -p /etc/forge && touch /etc/forge/egress_allowlist.json + +# forge-supervisor runs as PID 1 (UID 0), agent child runs as UID 1000 via exec.go +ENTRYPOINT ["/usr/local/bin/forge-supervisor"] diff --git a/forge-supervisor/audit.go b/forge-supervisor/audit.go new file mode 100644 index 0000000..52f8911 --- /dev/null +++ b/forge-supervisor/audit.go @@ -0,0 +1,54 @@ +package main + +import ( + "encoding/json" + "log" + "os" + "sync" + "time" +) + +// AuditEvent represents an audit log entry in NDJSON format. +type AuditEvent struct { + Timestamp time.Time `json:"timestamp"` + Action string `json:"action"` // "allowed", "denied", "exit" + Host string `json:"host,omitempty"` + Port int `json:"port,omitempty"` + PID int `json:"pid,omitempty"` + ExitCode int `json:"exit_code,omitempty"` +} + +// AuditLogger writes NDJSON audit events to stdout. +type AuditLogger struct { + mu sync.Mutex +} + +// NewAuditLogger creates a new AuditLogger. +func NewAuditLogger() *AuditLogger { + return &AuditLogger{} +} + +// Log writes an audit event to stdout in NDJSON format. +func (a *AuditLogger) Log(event *AuditEvent) { + a.mu.Lock() + defer a.mu.Unlock() + + data, err := json.Marshal(event) + if err != nil { + log.Printf("ERROR: marshal audit event: %v", err) + return + } + + os.Stdout.Write(data) + os.Stdout.Write([]byte("\n")) +} + +// LogExitEvent logs an agent exit event. +func (a *AuditLogger) LogExitEvent(pid, exitCode int) { + a.Log(&AuditEvent{ + Timestamp: time.Now().UTC(), + Action: "exit", + PID: pid, + ExitCode: exitCode, + }) +} diff --git a/forge-supervisor/exec.go b/forge-supervisor/exec.go new file mode 100644 index 0000000..1a1cec0 --- /dev/null +++ b/forge-supervisor/exec.go @@ -0,0 +1,61 @@ +package main + +import ( + "fmt" + "log" + "os" + "os/exec" + "syscall" + + "golang.org/x/sys/unix" +) + +// ExecAgent forks and executes the agent process as UID 1000. +// The supervisor stays as UID 0 so its own traffic is NOT redirected +// by the iptables OUTPUT chain (which targets UID 1000 only). +// Returns the *os.Process of the child. +func ExecAgent(args []string) (*os.Process, error) { + path, err := exec.LookPath(args[0]) + if err != nil { + return nil, fmt.Errorf("lookpath %q: %w", args[0], err) + } + + cmd := exec.Command(path, args[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Set UID/GID on the child — supervisor stays as UID 0. + // iptables redirects only UID 1000 traffic, so supervisor is unaffected. + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{Uid: 1000, Gid: 1000}, + Setsid: true, + // Setctty only when stdin is a TTY (containers may not have one) + Setctty: isStdinTTY(), + } + + if err := cmd.Start(); err != nil { + return nil, fmt.Errorf("start: %w", err) + } + + log.Printf("INFO: started agent (PID %d) as UID 1000: %s", cmd.Process.Pid, path) + return cmd.Process, nil +} + +// ForwardSignal forwards a signal to the process. +func ForwardSignal(pid int, sig syscall.Signal) { + proc, err := os.FindProcess(pid) + if err != nil { + log.Printf("ERROR: find process %d: %v", pid, err) + return + } + if err := proc.Signal(sig); err != nil { + log.Printf("ERROR: signal %d to %d: %v", sig, pid, err) + } +} + +// isStdinTTY returns true if stdin is a terminal. +func isStdinTTY() bool { + _, err := unix.IoctlGetTermios(int(os.Stdin.Fd()), unix.TIOCGWINSZ) + return err == nil +} diff --git a/forge-supervisor/go.mod b/forge-supervisor/go.mod new file mode 100644 index 0000000..ca22f47 --- /dev/null +++ b/forge-supervisor/go.mod @@ -0,0 +1,11 @@ +module github.com/initializ/forge/forge-supervisor + +go 1.25.0 + +toolchain go1.25.0 + +require ( + github.com/initializ/forge/forge-core v0.0.0 +) + +replace github.com/initializ/forge/forge-core => ../forge-core diff --git a/forge-supervisor/health.go b/forge-supervisor/health.go new file mode 100644 index 0000000..b27575a --- /dev/null +++ b/forge-supervisor/health.go @@ -0,0 +1,72 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "net/http" + "sync" + "time" +) + +// DenialEvent represents a single egress denial event. +type DenialEvent struct { + Timestamp time.Time `json:"timestamp"` + Host string `json:"host"` + Port int `json:"port"` +} + +// DenialTracker stores denial events for the /denials endpoint. +type DenialTracker struct { + mu sync.RWMutex + denials []DenialEvent +} + +// Add records a new denial event. +func (d *DenialTracker) Add(event DenialEvent) { + d.mu.Lock() + defer d.mu.Unlock() + d.denials = append(d.denials, event) + // Keep only the last 1000 denials + if len(d.denials) > 1000 { + d.denials = d.denials[len(d.denials)-1000:] + } +} + +// GetAll returns all recorded denial events. +func (d *DenialTracker) GetAll() []DenialEvent { + d.mu.RLock() + defer d.mu.RUnlock() + result := make([]DenialEvent, len(d.denials)) + copy(result, d.denials) + return result +} + +// StartHealthEndpoints starts HTTP endpoints for health checks. +func StartHealthEndpoints(tracker *DenialTracker, port int) { + mux := http.NewServeMux() + + mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("OK")) + }) + + mux.HandleFunc("/denials", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + denials := tracker.GetAll() + if err := json.NewEncoder(w).Encode(denials); err != nil { + log.Printf("ERROR: encode denials: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + }) + + addr := fmt.Sprintf("127.0.0.1:%d", port) + log.Printf("INFO: health endpoints listening on %s", addr) + + go func() { + if err := http.ListenAndServe(addr, mux); err != nil && err != http.ErrServerClosed { + log.Printf("ERROR: health server: %v", err) + } + }() +} diff --git a/forge-supervisor/http.go b/forge-supervisor/http.go new file mode 100644 index 0000000..94fe0d3 --- /dev/null +++ b/forge-supervisor/http.go @@ -0,0 +1,54 @@ +package main + +import ( + "bufio" + "io" + "net" + "strings" +) + +// ExtractHTTPHost reads an HTTP request from the connection (using initialBytes +// as the start) to find the Host header. Returns consumed bytes (for replay) +// and the hostname. +func ExtractHTTPHost(initialBytes []byte, conn net.Conn) ([]byte, string) { + reader := bufio.NewReader(io.MultiReader( + strings.NewReader(string(initialBytes)), + conn, + )) + + // Read and consume request line + _, err := reader.ReadString('\n') + if err != nil { + return initialBytes, "" + } + + // Read headers + for { + line, err := reader.ReadString('\n') + if err != nil { + return initialBytes, "" + } + + // End of headers + if line == "\r\n" || line == "\n" { + break + } + + // Host: header — line is "Host: value\r\n" + // "Host:" is 5 characters + if strings.HasPrefix(strings.ToLower(line), "host:") { + host := strings.TrimSpace(line[5:]) // Skip "Host:" (5 chars) + host = strings.TrimSuffix(host, "\r") + host = strings.ToLower(host) + // Remove port if present + if idx := strings.Index(host, ":"); idx != -1 { + host = host[:idx] + } + // Consume all bytes up to and including headers + consumed := append([]byte(line), []byte("\r\n")...) + return consumed, host + } + } + + return initialBytes, "" +} diff --git a/forge-supervisor/integration_test.go b/forge-supervisor/integration_test.go new file mode 100644 index 0000000..5bd8f8b --- /dev/null +++ b/forge-supervisor/integration_test.go @@ -0,0 +1,206 @@ +package main + +import ( + "encoding/json" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" +) + +// TestIntegration tests the supervisor by building it and verifying +// the binary starts and responds to health checks. +func TestIntegration(t *testing.T) { + // Skip if not in integration test mode + if os.Getenv("RUN_INTEGRATION_TESTS") != "1" { + t.Skip("Skipping integration test (set RUN_INTEGRATION_TESTS=1 to run)") + } + + // Build the supervisor + cmd := exec.Command("go", "build", "-o", "forge-supervisor-test", ".") + cmd.Dir = filepath.Dir(os.Args[0]) + output, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("failed to build supervisor: %v\noutput: %s", err, string(output)) + } + defer os.Remove("forge-supervisor-test") + + // Create a temporary egress policy file + tempDir := t.TempDir() + policyPath := filepath.Join(tempDir, "egress_allowlist.json") + policy := `{ + "mode": "allowlist", + "allowed_domains": ["example.com", "*.github.com"], + "allow_private_ips": false + }` + if err := os.WriteFile(policyPath, []byte(policy), 0644); err != nil { + t.Fatalf("failed to write policy file: %v", err) + } + + // Copy policy to current dir for the test + defer os.Remove(policyPath) + + // Note: We can't actually run the full supervisor in a test because: + // 1. It needs to be PID 1 + // 2. It needs CAP_NET_ADMIN for iptables + // 3. It needs to fork/exec an agent + // + // Instead, we test the individual components. + + t.Run("PolicyLoading", func(t *testing.T) { + testPolicyLoading(t, policyPath) + }) + + t.Run("DomainMatcher", func(t *testing.T) { + testDomainMatcher(t) + }) + + t.Run("HealthEndpoints", func(t *testing.T) { + testHealthEndpoints(t) + }) + + t.Run("AuditLogger", func(t *testing.T) { + testAuditLogger(t) + }) +} + +func testPolicyLoading(t *testing.T, policyPath string) { + // Write a test policy to a temp location + tmpPolicy := `{ + "mode": "allowlist", + "allowed_domains": ["test.com", "*.example.com"], + "allow_private_ips": false + }` + tmpFile := filepath.Join(t.TempDir(), "test_policy.json") + if err := os.WriteFile(tmpFile, []byte(tmpPolicy), 0644); err != nil { + t.Fatalf("failed to write temp policy: %v", err) + } + + policy, err := LoadPolicy(tmpFile) + if err != nil { + t.Fatalf("LoadPolicy failed: %v", err) + } + + if policy.Mode != "allowlist" { + t.Errorf("expected mode 'allowlist', got %q", policy.Mode) + } + + if len(policy.AllowedDomains) != 2 { + t.Errorf("expected 2 domains, got %d", len(policy.AllowedDomains)) + } +} + +func testDomainMatcher(t *testing.T) { + // Test that we can create a matcher and check domains + // Note: This tests the import from forge-core works + policy := `{ + "mode": "allowlist", + "allowed_domains": ["example.com", "*.github.com"], + "allow_private_ips": false + }` + tmpFile := filepath.Join(t.TempDir(), "test_policy.json") + if err := os.WriteFile(tmpFile, []byte(policy), 0644); err != nil { + t.Fatalf("failed to write temp policy: %v", err) + } + + p, err := LoadPolicy(tmpFile) + if err != nil { + t.Fatalf("LoadPolicy failed: %v", err) + } + + // The matcher is created in main.go, but we can verify the policy + // has the right structure for the matcher + if len(p.AllowedDomains) != 2 { + t.Errorf("expected 2 domains, got %d", len(p.AllowedDomains)) + } +} + +func testHealthEndpoints(t *testing.T) { + // Create a denial tracker and start health endpoints + tracker := &DenialTracker{} + StartHealthEndpoints(tracker, 15000) + + // Give the server time to start + time.Sleep(100 * time.Millisecond) + + // Test /healthz + resp, err := http.Get("http://127.0.0.1:15000/healthz") + if err != nil { + t.Fatalf("healthz request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("healthz returned status %d, expected 200", resp.StatusCode) + } + + // Test /denials + resp, err = http.Get("http://127.0.0.1:15000/denials") + if err != nil { + t.Fatalf("denials request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("denials returned status %d, expected 200", resp.StatusCode) + } + + // Check content type + if ct := resp.Header.Get("Content-Type"); !strings.Contains(ct, "application/json") { + t.Errorf("denials Content-Type = %q, expected application/json", ct) + } + + // Add a denial and verify it's returned + tracker.Add(DenialEvent{ + Timestamp: time.Now(), + Host: "blocked.example.com", + Port: 443, + }) + + resp, err = http.Get("http://127.0.0.1:15000/denials") + if err != nil { + t.Fatalf("denials request failed: %v", err) + } + defer resp.Body.Close() + + var denials []DenialEvent + if err := json.NewDecoder(resp.Body).Decode(&denials); err != nil { + t.Fatalf("failed to decode denials: %v", err) + } + + if len(denials) != 1 { + t.Errorf("expected 1 denial, got %d", len(denials)) + } + + if denials[0].Host != "blocked.example.com" { + t.Errorf("expected host 'blocked.example.com', got %q", denials[0].Host) + } +} + +func testAuditLogger(t *testing.T) { + // Create audit logger and verify it produces NDJSON + _ = NewAuditLogger() + + // This would write to stdout - in tests we verify the struct is correct + event := &AuditEvent{ + Timestamp: time.Now().UTC(), + Action: "allowed", + Host: "example.com", + Port: 443, + } + + // Verify the event can be marshaled to JSON + data, err := json.Marshal(event) + if err != nil { + t.Fatalf("failed to marshal audit event: %v", err) + } + + // Verify it's valid NDJSON (single line JSON) + lines := strings.Split(string(data), "\n") + if len(lines) != 1 { + t.Errorf("expected 1 line, got %d", len(lines)) + } +} diff --git a/forge-supervisor/iptables.go b/forge-supervisor/iptables.go new file mode 100644 index 0000000..1a02ec3 --- /dev/null +++ b/forge-supervisor/iptables.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "fmt" + "log" + "os/exec" + "strings" + "time" +) + +const ( + redirectPort = 15001 + targetUID = "1000" + waitTimeout = 5 * time.Second +) + +// SetupIPTables configures iptables (nat table) to redirect outgoing TCP traffic +// from UID 1000 to the local proxy on redirectPort. Runs as UID 0 (supervisor), +// so its own traffic is NOT redirected — only the agent's UID 1000 traffic is. +// Logs a warning and continues if iptables is not available. +func SetupIPTables(ctx context.Context, uid int, proxyPort int) error { + if !isIPTablesAvailable() { + log.Printf("WARN: iptables not available, skipping redirect setup (cap_net_admin may be denied)") + return nil + } + + cleanupIPTables(ctx) + + chain := "FORGE_SUPERVISOR" + uidStr := fmt.Sprintf("%d", uid) + portStr := fmt.Sprintf("%d", proxyPort) + + // REDIRECT target is only valid in the nat table + cmds := []struct { + name string + args []string + }{ + // Create custom chain in nat table + {"iptables", []string{"-t", "nat", "-N", chain}}, + // Match outgoing TCP from UID 1000, jump to custom chain + {"iptables", []string{"-t", "nat", "-A", "OUTPUT", "-m", "owner", "--uid-owner", uidStr, "-p", "tcp", "-j", chain}}, + // Redirect to proxy port + {"iptables", []string{"-t", "nat", "-A", chain, "-p", "tcp", "-j", "REDIRECT", "--to-port", portStr}}, + } + + for _, cmd := range cmds { + if err := runIPTables(ctx, cmd.name, cmd.args...); err != nil { + if strings.Contains(err.Error(), "Chain already exists") { + continue + } + log.Printf("WARN: iptables setup failed: %v", err) + return nil + } + } + + log.Printf("INFO: iptables nat redirect configured: UID %d -> port %d", uid, proxyPort) + return nil +} + +// isIPTablesAvailable checks if iptables exists and is executable. +func isIPTablesAvailable() bool { + ctx, cancel := context.WithTimeout(context.Background(), waitTimeout) + defer cancel() + cmd := exec.CommandContext(ctx, "iptables", "--version") + return cmd.Run() == nil +} + +// cleanupIPTables removes any existing FORGE_SUPERVISOR chain rules. +func cleanupIPTables(ctx context.Context) { + chain := "FORGE_SUPERVISOR" + uidStr := targetUID + + runIPTables(ctx, "iptables", "-t", "nat", "-F", chain) + runIPTables(ctx, "iptables", "-t", "nat", "-D", "OUTPUT", "-m", "owner", "--uid-owner", uidStr, "-p", "tcp", "-j", chain) + runIPTables(ctx, "iptables", "-t", "nat", "-X", chain) +} + +// runIPTables executes an iptables command with the given arguments. +func runIPTables(ctx context.Context, name string, args ...string) error { + cmd := exec.CommandContext(ctx, name, args...) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("iptables %v: %s: %w", args, string(out), err) + } + return nil +} diff --git a/forge-supervisor/main.go b/forge-supervisor/main.go new file mode 100644 index 0000000..8b68c41 --- /dev/null +++ b/forge-supervisor/main.go @@ -0,0 +1,105 @@ +package main + +import ( + "context" + "log" + "os" + "os/signal" + "strconv" + "syscall" + + "github.com/initializ/forge/forge-core/security" +) + +func main() { + log.SetFlags(0) + log.SetOutput(os.Stdout) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Load egress policy — path from env or default + policyPath := os.Getenv("FORGE_SUPERVISOR_POLICY_PATH") + if policyPath == "" { + policyPath = "/etc/forge/egress_allowlist.json" + } + + policy, err := LoadPolicy(policyPath) + if err != nil { + log.Fatalf("FATAL: failed to load policy from %q: %v", policyPath, err) + } + + // Create domain matcher + matcher := security.NewDomainMatcher(policy.Mode, policy.AllowedDomains) + + // Ports from env or defaults + proxyPort := 15001 + if p := os.Getenv("FORGE_SUPERVISOR_PROXY_PORT"); p != "" { + if v, err := strconv.Atoi(p); err == nil && v > 0 && v < 65536 { + proxyPort = v + } + } + healthPort := 15000 + if h := os.Getenv("FORGE_SUPERVISOR_HEALTH_PORT"); h != "" { + if v, err := strconv.Atoi(h); err == nil && v > 0 && v < 65536 { + healthPort = v + } + } + + // Set up iptables REDIRECT for UID 1000 — supervisor stays UID 0 + if err := SetupIPTables(ctx, 1000, proxyPort); err != nil { + log.Printf("WARNING: iptables setup failed (may lack CAP_NET_ADMIN): %v", err) + } + + // Start audit logger + audit := NewAuditLogger() + + // Start health endpoints + denialTracker := &DenialTracker{denials: []DenialEvent{}} + StartHealthEndpoints(denialTracker, healthPort) + + // Create transparent proxy + proxy := NewTransparentProxy(matcher, denialTracker, audit) + if err := proxy.Start(ctx, ":"+strconv.Itoa(proxyPort)); err != nil { + log.Fatalf("FATAL: failed to start proxy: %v", err) + } + + // NOTE: Do NOT drop privileges on the supervisor process. + // The supervisor runs as UID 0 so its own traffic is not redirected. + // Only the agent child process (exec.go) runs as UID 1000. + + // Fork/exec the agent process — runs as UID 1000 via exec.go + agentCmd := os.Args[1:] + if len(agentCmd) == 0 { + agentCmd = []string{"/bin/sh", "-l"} + } + + proc, err := ExecAgent(agentCmd) + if err != nil { + log.Fatalf("FATAL: failed to exec agent: %v", err) + } + + // Forward signals to agent + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGCHLD) + + for { + select { + case s := <-sigChan: + switch s { + case syscall.SIGCHLD: + var status syscall.WaitStatus + syscall.Wait4(proc.Pid, &status, 0, nil) + if status.Exited() { + audit.LogExitEvent(proc.Pid, status.ExitStatus()) + cancel() + return + } + default: + ForwardSignal(proc.Pid, s.(syscall.Signal)) + } + case <-ctx.Done(): + return + } + } +} diff --git a/forge-supervisor/policy.go b/forge-supervisor/policy.go new file mode 100644 index 0000000..fce6770 --- /dev/null +++ b/forge-supervisor/policy.go @@ -0,0 +1,39 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/initializ/forge/forge-core/security" +) + +// Policy represents the egress policy loaded from egress_allowlist.json +type Policy struct { + Mode security.EgressMode `json:"mode"` + AllowedDomains []string `json:"allowed_domains"` + AllowPrivateIPs bool `json:"allow_private_ips"` +} + +// LoadPolicy loads the egress policy from the specified JSON file. +func LoadPolicy(path string) (*Policy, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read policy file: %w", err) + } + + var policy Policy + if err := json.Unmarshal(data, &policy); err != nil { + return nil, fmt.Errorf("parse policy: %w", err) + } + + // Validate mode + switch policy.Mode { + case security.ModeAllowlist, security.ModeDenyAll, security.ModeDevOpen: + // Valid + default: + return nil, fmt.Errorf("invalid egress mode: %q", policy.Mode) + } + + return &policy, nil +} diff --git a/forge-supervisor/privdrop.go b/forge-supervisor/privdrop.go new file mode 100644 index 0000000..834880b --- /dev/null +++ b/forge-supervisor/privdrop.go @@ -0,0 +1,60 @@ +package main + +import ( + "fmt" + "log" + "os" + + "golang.org/x/sys/unix" +) + +// DropPrivileges drops root privileges to the specified UID/GID. +// It clears all supplementary groups, drops all capabilities, +// and sets PR_SET_NO_NEW_PRIVS to prevent privilege escalation. +func DropPrivileges(uid, gid int) error { + log.Printf("INFO: dropping privileges to UID %d, GID %d", uid, gid) + + // Clear supplementary groups before setgid (required on some systems) + if err := unix.Setgroups([]int{gid}); err != nil { + log.Printf("WARN: Setgroups: %v (continuing)", err) + } + + // Set no_new_privs before dropping privileges + if err := unix.Prctl(unix.PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0); err != nil { + log.Printf("WARN: PR_SET_NO_NEW_PRIVS: %v", err) + } + + // Drop all capabilities (bounding set) + if err := dropAllCapabilities(); err != nil { + log.Printf("WARN: drop capabilities: %v", err) + } + + // Set GID first (required by some systems) + if err := unix.Setgid(gid); err != nil { + return fmt.Errorf("setgid: %w", err) + } + + // Set UID + if err := unix.Setuid(uid); err != nil { + return fmt.Errorf("setuid: %w", err) + } + + // Verify the drop + currentUID := os.Getuid() + currentGID := os.Getgid() + if currentUID != uid || currentGID != gid { + return fmt.Errorf("privilege drop failed: UID=%d GID=%d (wanted %d %d)", currentUID, currentGID, uid, gid) + } + + log.Printf("INFO: privileges dropped successfully") + return nil +} + +// dropAllCapabilities clears the capability bounding set. +func dropAllCapabilities() error { + // Clear the capability bounding set (limits what can be raised) + for cap := 0; cap <= 40; cap++ { + unix.Prctl(unix.PR_CAPBSET_DROP, uintptr(cap), 0, 0, 0) + } + return nil +} diff --git a/forge-supervisor/proxy.go b/forge-supervisor/proxy.go new file mode 100644 index 0000000..a3e894e --- /dev/null +++ b/forge-supervisor/proxy.go @@ -0,0 +1,293 @@ +package main + +import ( + "context" + "encoding/binary" + "fmt" + "io" + "log" + "net" + "sync" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/unix" + + "github.com/initializ/forge/forge-core/security" +) + +// TransparentProxy intercepts redirected TCP traffic, extracts the target hostname +// via SNI or HTTP Host header, checks against the domain matcher, and either +// forwards or denies the connection. +type TransparentProxy struct { + listener net.Listener + matcher *security.DomainMatcher + denialTracker *DenialTracker + audit *AuditLogger + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +// NewTransparentProxy creates a new transparent TCP proxy. +func NewTransparentProxy(matcher *security.DomainMatcher, denialTracker *DenialTracker, audit *AuditLogger) *TransparentProxy { + ctx, cancel := context.WithCancel(context.Background()) + return &TransparentProxy{ + matcher: matcher, + denialTracker: denialTracker, + audit: audit, + ctx: ctx, + cancel: cancel, + } +} + +// Start begins listening and accepting connections. +func (p *TransparentProxy) Start(ctx context.Context, addr string) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("proxy listen: %w", err) + } + p.listener = ln + + log.Printf("INFO: transparent proxy listening on %s", ln.Addr().String()) + + p.wg.Add(1) + go p.acceptLoop() + + return nil +} + +// Stop gracefully shuts down the proxy. +func (p *TransparentProxy) Stop() error { + p.cancel() + if p.listener != nil { + p.listener.Close() + } + p.wg.Wait() + return nil +} + +func (p *TransparentProxy) acceptLoop() { + defer p.wg.Done() + + for { + conn, err := p.listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + log.Printf("ERROR: proxy accept: %v", err) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(conn) + } +} + +func (p *TransparentProxy) handleConnection(client net.Conn) { + defer p.wg.Done() + defer client.Close() + + // Get original destination from redirected connection + origAddr, err := getOriginalDst(client) + if err != nil { + log.Printf("ERROR: get original dst: %v", err) + return + } + + origTCPAddr, ok := origAddr.(*net.TCPAddr) + if !ok { + log.Printf("ERROR: unexpected addr type: %T", origAddr) + return + } + + // Extract host and any consumed bytes (for replay) + host, consumed, allowed := p.extractAndCheck(client, origTCPAddr) + + if !allowed { + p.denialTracker.Add(DenialEvent{ + Timestamp: time.Now().UTC(), + Host: host, + Port: origTCPAddr.Port, + }) + + p.audit.Log(&AuditEvent{ + Timestamp: time.Now().UTC(), + Action: "denied", + Host: host, + Port: origTCPAddr.Port, + }) + + fmt.Fprintf(client, "HTTP/1.1 403 Forbidden\r\n\r\n") + return + } + + // Log allowed connection + p.audit.Log(&AuditEvent{ + Timestamp: time.Now().UTC(), + Action: "allowed", + Host: host, + Port: origTCPAddr.Port, + }) + + // Dial original destination + upstream, err := net.DialTimeout("tcp", origTCPAddr.String(), 10*time.Second) + if err != nil { + log.Printf("ERROR: dial upstream %s: %v", origTCPAddr, err) + return + } + defer upstream.Close() + + // Replay consumed bytes to upstream, then relay the rest + p.relay(client, upstream, consumed) +} + +// extractAndCheck peeks at the first bytes to determine TLS vs HTTP, +// extracts the hostname, and checks against the allowlist. +// Returns host, consumed bytes (for replay), and whether it's allowed. +func (p *TransparentProxy) extractAndCheck(conn net.Conn, addr *net.TCPAddr) (string, []byte, bool) { + // Peek at first 5 bytes (TLS record header max) + firstBytes := make([]byte, 5) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + + n, err := conn.Read(firstBytes) + if err != nil && err != io.EOF { + log.Printf("ERROR: peek bytes: %v", err) + return fmt.Sprintf("%s:%d", addr.IP, addr.Port), nil, false + } + + var host string + var consumed []byte + + if n >= 3 && firstBytes[0] == 0x16 && firstBytes[1] == 0x03 { + // TLS ClientHello + sniBytes, sniHost := ExtractSNIFromClientHello(firstBytes[:n], conn) + if sniHost != "" { + host = sniHost + consumed = append(consumed, sniBytes...) + } else { + host = fmt.Sprintf("%s:%d", addr.IP, addr.Port) + } + } else { + // HTTP — read Host header, replaying consumed bytes first + httpBytes, httpHost := ExtractHTTPHost(firstBytes[:n], conn) + if httpHost != "" { + host = httpHost + consumed = httpBytes + } else { + host = fmt.Sprintf("%s:%d", addr.IP, addr.Port) + } + } + + // Validate host against SSRF bypass patterns + if err := security.ValidateHostIP(host); err != nil { + log.Printf("WARN: invalid host format %q: %v", host, err) + return host, consumed, false + } + + allowed := p.matcher.IsAllowed(host) + return host, consumed, allowed +} + +// relay copies data between client and upstream bidirectionally. +// consumed bytes are written to upstream first (they were peeked during extraction). +func (p *TransparentProxy) relay(client, upstream net.Conn, consumed []byte) { + var wg sync.WaitGroup + wg.Add(2) + + // upstream <- client (plain copy) + go func() { + defer wg.Done() + io.Copy(upstream, client) + upstream.Close() + }() + + // client <- upstream, but first write consumed bytes + go func() { + defer wg.Done() + // Send consumed bytes first (TLS ClientHello or HTTP request line+headers) + if len(consumed) > 0 { + upstreamCopy := &peekReader{r: upstream, peeked: consumed} + io.Copy(client, upstreamCopy) + // Then continue with remaining upstream data + io.Copy(client, upstream) + } else { + io.Copy(client, upstream) + } + client.Close() + }() + + wg.Wait() +} + +// peekReader wraps a reader and returns embedded bytes first, then reads from underlying. +type peekReader struct { + r net.Conn + peeked []byte + peekIdx int +} + +func (p *peekReader) Read(b []byte) (int, error) { + if p.peekIdx < len(p.peeked) { + n := copy(b, p.peeked[p.peekIdx:]) + p.peekIdx += n + return n, nil + } + return p.r.Read(b) +} + +// getOriginalDst retrieves the original destination for an iptables-redirected connection. +func getOriginalDst(conn net.Conn) (net.Addr, error) { + sc, ok := conn.(syscall.Conn) + if !ok { + return nil, fmt.Errorf("not a syscall.Conn") + } + + rc, err := sc.SyscallConn() + if err != nil { + return nil, fmt.Errorf("SyscallConn: %w", err) + } + + var origAddr unix.RawSockaddrAny + var sockLen int32 = int32(unix.SizeofSockaddrAny) + + err = rc.Control(func(fd uintptr) { + ret, _, _ := unix.Syscall6( + unix.SYS_GETSOCKOPT, + fd, + uintptr(unix.IPPROTO_IP), + uintptr(unix.SO_ORIGINAL_DST), + uintptr(unsafe.Pointer(&origAddr)), + uintptr(unsafe.Pointer(&sockLen)), + 0) + if ret != 0 { + err = syscall.Errno(ret) + } + }) + if err != nil { + return nil, fmt.Errorf("getsockopt SO_ORIGINAL_DST: %v", err) + } + + if origAddr.Addr.Family == unix.AF_INET { + ptr4 := (*unix.RawSockaddrInet4)(unsafe.Pointer(&origAddr)) + // Port is stored in network byte order — convert to host byte order + port := int(binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&ptr4.Port))[:])) + return &net.TCPAddr{ + IP: net.IP(ptr4.Addr[:]), + Port: port, + }, nil + } else if origAddr.Addr.Family == unix.AF_INET6 { + ptr6 := (*unix.RawSockaddrInet6)(unsafe.Pointer(&origAddr)) + port := int(binary.BigEndian.Uint16((*[2]byte)(unsafe.Pointer(&ptr6.Port))[:])) + return &net.TCPAddr{ + IP: net.IP(ptr6.Addr[:]), + Port: port, + }, nil + } + + return nil, fmt.Errorf("unknown address family: %d", origAddr.Addr.Family) +} diff --git a/forge-supervisor/sni.go b/forge-supervisor/sni.go new file mode 100644 index 0000000..5aa9cb1 --- /dev/null +++ b/forge-supervisor/sni.go @@ -0,0 +1,115 @@ +package main + +import ( + "io" + "net" +) + +// ExtractSNIFromClientHello extracts the Server Name Indication (SNI) +// from a TLS ClientHello message. It peeks at the ClientHello without +// terminating TLS. Returns all consumed bytes (for replay) and hostname. +func ExtractSNIFromClientHello(firstBytes []byte, conn net.Conn) ([]byte, string) { + // TLS record header: content_type (1) + version (2) + length (2) + // For ClientHello, content_type = 0x16 (handshake) + if len(firstBytes) < 5 { + return firstBytes, "" + } + + // Read full handshake header: type (1) + length (3) + handshakeHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, handshakeHeader); err != nil { + return firstBytes, "" + } + + // handshake type should be 0x01 (ClientHello) + if handshakeHeader[0] != 0x01 { + return firstBytes, "" + } + + // Read ClientHello body + // We read a reasonable chunk and scan for SNI extension (type 0x0000) + body := make([]byte, 1024) + n, _ := io.ReadFull(conn, body) + + // TLS ClientHello structure after random: + // offset 0-33: client_version (2) + random (32) + // offset 34: session_id_length (1) + // offset 35+: session_id (variable) + // Then: cipher_suites_length (2), cipher_suites (variable) + // Then: compression_methods_length (1), compression_methods (variable) + // Then: extensions_length (2) + // Then: extensions (variable) + + offset := 34 + + if offset >= n { + return firstBytes, "" + } + + // session_id_length (1) + sessionIDLen := int(body[offset]) + offset += 1 + sessionIDLen + + if offset+2 >= n { + return firstBytes, "" + } + + // cipher_suites_length (2) + cipherLen := int(body[offset])<<8 | int(body[offset+1]) + offset += 2 + cipherLen + + if offset >= n { + return firstBytes, "" + } + + // compression_methods_length (1) + compressionLen := int(body[offset]) + offset += 1 + compressionLen + + if offset+2 >= n { + return firstBytes, "" + } + + // extensions_length (2) + extensionsLen := int(body[offset])<<8 | int(body[offset+1]) + offset += 2 + + if offset+extensionsLen > n { + return firstBytes, "" + } + + // Scan extensions for SNI (type 0x0000) + extensionsEnd := offset + extensionsLen + for offset+4 <= extensionsEnd { + extType := int(body[offset])<<8 | int(body[offset+1]) + extLen := int(body[offset+2])<<8 | int(body[offset+3]) + offset += 4 + + if extType == 0 { // SNI extension + // SNI extension_data structure: + // server_name_list_length (2 bytes) — total length of following list + // For each entry: + // name_type (1 byte) — 0x00 = hostname + // name_length (2 bytes) + // name (name_length bytes) + if offset+4 > extensionsEnd { + return firstBytes, "" + } + // name_length is at body[offset+2] and body[offset+3] + // (offset+0 = server_name_list_length high, + // offset+1 = server_name_list_length low, + // offset+2 = name_type, offset+3 = name_length high) + nameLen := int(body[offset+2])<<8 | int(body[offset+3]) + if offset+4+nameLen > extensionsEnd { + return firstBytes, "" + } + // name starts at offset+4 (after: list_len(2) + name_type(1) + name_len(2)) + consumed := append(handshakeHeader, body[:n]...) + return consumed, string(body[offset+4 : offset+4+nameLen]) + } + + offset += extLen + } + + return firstBytes, "" +} diff --git a/go.work b/go.work index 687ca41..ab85358 100644 --- a/go.work +++ b/go.work @@ -6,4 +6,5 @@ use ( ./forge-plugins ./forge-skills ./forge-ui + ./forge-supervisor )