package agents import ( "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "os" "path/filepath" "regexp" "sort" "strings" "sync" "syscall" "time" "codex-agent-manager/internal/codexhome" "golang.org/x/sys/unix" ) var ErrWriteConflict = errors.New("目标文件已在校验后发生变化") var safeAgentID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_-]*$`) var writebackMu sync.Mutex var writebackTestHookBeforeBackup func() var writebackTestHookAfterVerifyBeforeBackup func() var writebackTestHookAfterBackup func() type fileIdentity struct { dev uint64 ino uint64 } type writeTarget struct { path string base string content []byte mode os.FileMode agentsIdentity fileIdentity targetIdentity fileIdentity } type agentsDirHandle struct { fd int path string identity fileIdentity } func (s Store) ValidateDraft(id string, content string) (DraftValidation, error) { target, err := s.readWriteTarget(id) if err != nil { return DraftValidation{}, err } result := DraftValidation{ TargetPath: target.path, CurrentHash: hashBytes(target.content), } currentFields, currentErr := parseSimpleTOML(string(target.content)) draftFields, draftErr := parseSimpleTOML(content) if draftErr != nil { result.Valid = false result.Errors = []string{draftErr.Error()} result.Diff = simpleDiff(string(target.content), content) return result, nil } result.Valid = true result.Errors = []string{} result.Diff = simpleDiff(string(target.content), content) if currentErr == nil { result.FieldChanges = changedFields(currentFields, draftFields) } return result, nil } func (s Store) WriteDraft(id string, content string, expectedHash string) (WriteResult, error) { writebackMu.Lock() defer writebackMu.Unlock() validation, err := s.ValidateDraft(id, content) if err != nil { return WriteResult{}, err } if !validation.Valid { return WriteResult{}, errors.New(strings.Join(validation.Errors, ";")) } if expectedHash == "" || validation.CurrentHash != expectedHash { return WriteResult{}, ErrWriteConflict } target, dir, err := s.openWriteTarget(id) if err != nil { return WriteResult{}, err } defer dir.close() if hashBytes(target.content) != expectedHash { return WriteResult{}, ErrWriteConflict } if writebackTestHookBeforeBackup != nil { writebackTestHookBeforeBackup() } if _, err := s.verifyWriteTarget(dir, id, target, expectedHash); err != nil { return WriteResult{}, err } if writebackTestHookAfterVerifyBeforeBackup != nil { writebackTestHookAfterVerifyBeforeBackup() } if _, err := s.verifyWriteTarget(dir, id, target, expectedHash); err != nil { return WriteResult{}, err } backupPath, err := s.createBackup(dir, target) if err != nil { return WriteResult{}, err } if writebackTestHookAfterBackup != nil { writebackTestHookAfterBackup() } if err := atomicWrite(dir, target, []byte(content), func() error { _, err := s.verifyWriteTarget(dir, id, target, expectedHash) return err }); err != nil { return WriteResult{}, err } return WriteResult{ Status: "written", TargetPath: target.path, BackupPath: backupPath, CurrentHash: hashBytes([]byte(content)), }, nil } func (s Store) readWriteTarget(id string) (writeTarget, error) { target, dir, err := s.openWriteTarget(id) if err != nil { return writeTarget{}, err } defer dir.close() return target, nil } func (s Store) openWriteTarget(id string) (writeTarget, agentsDirHandle, error) { if !safeAgentID.MatchString(id) { return writeTarget{}, agentsDirHandle{}, codexhome.ErrForbiddenPath } dir, err := openAgentsDir(s.CodexHome) if err != nil { return writeTarget{}, agentsDirHandle{}, err } target, err := s.readTargetFromDir(dir, id) if err != nil { dir.close() return writeTarget{}, agentsDirHandle{}, err } if err := ensureAgentsPathStillMatches(dir); err != nil { dir.close() return writeTarget{}, agentsDirHandle{}, err } return target, dir, nil } func openAgentsDir(home string) (agentsDirHandle, error) { path := filepath.Join(home, "agents") fd, err := unix.Open(path, unix.O_RDONLY|unix.O_DIRECTORY|unix.O_CLOEXEC|unix.O_NOFOLLOW, 0) if err != nil { if errors.Is(err, unix.ELOOP) || errors.Is(err, unix.ENOTDIR) { return agentsDirHandle{}, codexhome.ErrForbiddenPath } return agentsDirHandle{}, err } var stat unix.Stat_t if err := unix.Fstat(fd, &stat); err != nil { _ = unix.Close(fd) return agentsDirHandle{}, err } dir := agentsDirHandle{fd: fd, path: path, identity: identityOfUnix(stat)} if err := ensureAgentsPathStillMatches(dir); err != nil { _ = unix.Close(fd) return agentsDirHandle{}, err } return dir, nil } func (d agentsDirHandle) close() { if d.fd >= 0 { _ = unix.Close(d.fd) } } func ensureAgentsPathStillMatches(dir agentsDirHandle) error { info, err := os.Lstat(dir.path) if err != nil { if os.IsNotExist(err) { return ErrWriteConflict } return err } if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() { return ErrWriteConflict } identity, err := identityOf(info) if err != nil { return err } if identity != dir.identity { return ErrWriteConflict } return nil } func (s Store) readTargetFromDir(dir agentsDirHandle, id string) (writeTarget, error) { if !safeAgentID.MatchString(id) { return writeTarget{}, codexhome.ErrForbiddenPath } base := id + ".toml" if _, err := codexhome.ResolveAgentTOML(s.CodexHome, base); err != nil { return writeTarget{}, err } fd, err := unix.Openat(dir.fd, base, unix.O_RDONLY|unix.O_CLOEXEC|unix.O_NOFOLLOW, 0) if err != nil { if errors.Is(err, unix.ELOOP) { return writeTarget{}, codexhome.ErrForbiddenPath } return writeTarget{}, err } var stat unix.Stat_t if err := unix.Fstat(fd, &stat); err != nil { _ = unix.Close(fd) return writeTarget{}, err } if stat.Mode&unix.S_IFMT != unix.S_IFREG { _ = unix.Close(fd) return writeTarget{}, codexhome.ErrForbiddenPath } data, err := readAllFromFD(fd, base) if err != nil { return writeTarget{}, err } return writeTarget{ path: filepath.Join(dir.path, base), base: base, content: data, mode: os.FileMode(stat.Mode & 0o777), agentsIdentity: dir.identity, targetIdentity: identityOfUnix(stat), }, nil } func readAllFromFD(fd int, name string) ([]byte, error) { file := os.NewFile(uintptr(fd), name) if file == nil { _ = unix.Close(fd) return nil, errors.New("无法打开目标文件") } defer file.Close() return io.ReadAll(file) } func (s Store) verifyWriteTarget(dir agentsDirHandle, id string, expected writeTarget, expectedHash string) (writeTarget, error) { if err := ensureAgentsPathStillMatches(dir); err != nil { return writeTarget{}, err } current, err := s.readTargetFromDir(dir, id) if err != nil { if os.IsNotExist(err) { return writeTarget{}, ErrWriteConflict } return writeTarget{}, err } if current.path != expected.path || current.base != expected.base || current.agentsIdentity != expected.agentsIdentity || current.targetIdentity != expected.targetIdentity { return writeTarget{}, ErrWriteConflict } if hashBytes(current.content) != expectedHash { return writeTarget{}, ErrWriteConflict } return current, nil } func identityOf(info os.FileInfo) (fileIdentity, error) { switch stat := info.Sys().(type) { case *unix.Stat_t: return identityOfUnix(*stat), nil case *syscall.Stat_t: return fileIdentity{ dev: uint64(stat.Dev), ino: uint64(stat.Ino), }, nil default: return fileIdentity{}, errors.New("无法确认文件身份") } } func identityOfUnix(stat unix.Stat_t) fileIdentity { return fileIdentity{ dev: uint64(stat.Dev), ino: uint64(stat.Ino), } } func (s Store) createBackup(dir agentsDirHandle, target writeTarget) (string, error) { if err := ensureAgentsPathStillMatches(dir); err != nil { return "", err } backupName := fmt.Sprintf("%s.bak-%s", target.base, time.Now().UTC().Format("20060102T150405.000000000Z")) fd, err := unix.Openat(dir.fd, backupName, unix.O_WRONLY|unix.O_CREAT|unix.O_EXCL|unix.O_CLOEXEC|unix.O_NOFOLLOW, uint32(target.mode)) if err != nil { return "", err } file := os.NewFile(uintptr(fd), backupName) if file == nil { _ = unix.Close(fd) _ = unix.Unlinkat(dir.fd, backupName, 0) return "", errors.New("无法创建备份") } if _, err := file.Write(target.content); err != nil { _ = file.Close() _ = unix.Unlinkat(dir.fd, backupName, 0) return "", err } if err := file.Close(); err != nil { _ = unix.Unlinkat(dir.fd, backupName, 0) return "", err } return filepath.Join(dir.path, backupName), nil } func atomicWrite(dir agentsDirHandle, target writeTarget, content []byte, beforeRename func() error) error { tmpName := fmt.Sprintf(".%s.tmp-%d-%d", target.base, os.Getpid(), time.Now().UnixNano()) fd, err := unix.Openat(dir.fd, tmpName, unix.O_WRONLY|unix.O_CREAT|unix.O_EXCL|unix.O_CLOEXEC|unix.O_NOFOLLOW, uint32(target.mode)) if err != nil { return err } defer func() { _ = unix.Unlinkat(dir.fd, tmpName, 0) }() if err := unix.Fchmod(fd, uint32(target.mode)); err != nil { _ = unix.Close(fd) return err } file := os.NewFile(uintptr(fd), tmpName) if file == nil { _ = unix.Close(fd) return errors.New("无法创建临时文件") } if _, err := file.Write(content); err != nil { _ = file.Close() return err } if err := file.Close(); err != nil { return err } if beforeRename != nil { if err := beforeRename(); err != nil { return err } } return unix.Renameat(dir.fd, tmpName, dir.fd, target.base) } func hashBytes(data []byte) string { sum := sha256.Sum256(data) return hex.EncodeToString(sum[:]) } func changedFields(before map[string]string, after map[string]string) []FieldChange { keys := map[string]bool{} for key := range before { keys[key] = true } for key := range after { keys[key] = true } ordered := make([]string, 0, len(keys)) for key := range keys { ordered = append(ordered, key) } sort.Strings(ordered) changes := make([]FieldChange, 0) for _, key := range ordered { if before[key] == after[key] { continue } changes = append(changes, FieldChange{Field: key, Before: before[key], After: after[key]}) } return changes } func simpleDiff(before string, after string) string { if before == after { return "无差异\n" } beforeLines := splitLines(before) afterLines := splitLines(after) prefix := 0 for prefix < len(beforeLines) && prefix < len(afterLines) && beforeLines[prefix] == afterLines[prefix] { prefix++ } suffix := 0 for suffix < len(beforeLines)-prefix && suffix < len(afterLines)-prefix && beforeLines[len(beforeLines)-1-suffix] == afterLines[len(afterLines)-1-suffix] { suffix++ } var b strings.Builder b.WriteString("--- current\n+++ draft\n@@\n") for _, line := range beforeLines[prefix : len(beforeLines)-suffix] { b.WriteString("-") b.WriteString(line) b.WriteString("\n") } for _, line := range afterLines[prefix : len(afterLines)-suffix] { b.WriteString("+") b.WriteString(line) b.WriteString("\n") } return b.String() } func splitLines(input string) []string { trimmed := strings.TrimSuffix(input, "\n") if trimmed == "" { return []string{} } return strings.Split(trimmed, "\n") }