Files
codex-agent-manager/internal/agents/writeback.go

438 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package agents
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync"
"syscall"
"time"
"codex-agent-manager/internal/codexhome"
"golang.org/x/sys/unix"
)
var (
	// ErrWriteConflict is returned when the agents directory or the agent TOML
	// file no longer matches what was read and verified earlier in the same
	// operation, or when the caller's expected hash is empty or stale.
	ErrWriteConflict = errors.New("目标文件已在校验后发生变化")

	// safeAgentID admits IDs that start with an ASCII letter or digit followed
	// only by letters, digits, '_' or '-', so "<id>.toml" is always a single
	// path component (no separators, no dots).
	safeAgentID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_-]*$`)

	// writebackMu serializes WriteDraft calls within this process.
	writebackMu sync.Mutex

	// Hooks invoked at fixed points inside WriteDraft. They are nil by default
	// and WriteDraft calls each one only when it is non-nil.
	writebackTestHookBeforeBackup            func()
	writebackTestHookAfterVerifyBeforeBackup func()
	writebackTestHookAfterBackup             func()
)
type (
	// fileIdentity identifies a file by device and inode number, which lets
	// the writer detect that a path now names a different file.
	fileIdentity struct {
		dev, ino uint64
	}

	// writeTarget is a snapshot of an agent TOML file read through an open
	// agents directory descriptor.
	writeTarget struct {
		path           string       // agents directory path joined with base
		base           string       // file name, "<id>.toml"
		content        []byte       // file bytes at snapshot time
		mode           os.FileMode  // permission bits (stat mode & 0o777)
		agentsIdentity fileIdentity // identity of the agents directory
		targetIdentity fileIdentity // identity of the TOML file itself
	}

	// agentsDirHandle is an open descriptor for the agents directory together
	// with the path and identity it was opened with.
	agentsDirHandle struct {
		fd       int
		path     string
		identity fileIdentity
	}
)
// ValidateDraft checks content as a replacement for the TOML file of agent id
// without writing anything.
//
// A draft that fails to parse is reported inside the result (Valid=false, one
// entry in Errors) rather than as an error. The returned error is non-nil only
// when the current file cannot be opened or read. FieldChanges is populated
// only when both the current file and the draft parse.
func (s Store) ValidateDraft(id string, content string) (DraftValidation, error) {
	target, err := s.readWriteTarget(id)
	if err != nil {
		return DraftValidation{}, err
	}

	current := string(target.content)
	currentFields, currentErr := parseSimpleTOML(current)
	draftFields, draftErr := parseSimpleTOML(content)

	out := DraftValidation{
		TargetPath:  target.path,
		CurrentHash: hashBytes(target.content),
		Diff:        simpleDiff(current, content),
	}
	if draftErr != nil {
		out.Valid = false
		out.Errors = []string{draftErr.Error()}
		return out, nil
	}

	out.Valid = true
	out.Errors = []string{}
	// Field-level changes need a parseable baseline; skip them otherwise.
	if currentErr == nil {
		out.FieldChanges = changedFields(currentFields, draftFields)
	}
	return out, nil
}
// WriteDraft replaces the TOML file of agent id with content.
//
// The write proceeds only if content parses (per ValidateDraft) and
// expectedHash is non-empty and equal to the SHA-256 of the file currently on
// disk. Otherwise it returns the parse error text or ErrWriteConflict. The
// file is then re-opened through the agents directory descriptor and
// re-verified (path, directory and file identity, content hash) before a
// timestamped backup is made, and once more just before the atomic rename.
// Calls are serialized within this process by writebackMu.
func (s Store) WriteDraft(id string, content string, expectedHash string) (WriteResult, error) {
	writebackMu.Lock()
	defer writebackMu.Unlock()

	check, err := s.ValidateDraft(id, content)
	switch {
	case err != nil:
		return WriteResult{}, err
	case !check.Valid:
		return WriteResult{}, errors.New(strings.Join(check.Errors, ""))
	case expectedHash == "", check.CurrentHash != expectedHash:
		// An empty hash is refused so callers cannot skip the optimistic check.
		return WriteResult{}, ErrWriteConflict
	}

	target, dir, err := s.openWriteTarget(id)
	if err != nil {
		return WriteResult{}, err
	}
	defer dir.close()

	// The file may have changed between ValidateDraft and this second open.
	if hashBytes(target.content) != expectedHash {
		return WriteResult{}, ErrWriteConflict
	}

	fire := func(hook func()) {
		if hook != nil {
			hook()
		}
	}
	recheck := func() error {
		_, verr := s.verifyWriteTarget(dir, id, target, expectedHash)
		return verr
	}

	fire(writebackTestHookBeforeBackup)
	if err := recheck(); err != nil {
		return WriteResult{}, err
	}
	fire(writebackTestHookAfterVerifyBeforeBackup)
	if err := recheck(); err != nil {
		return WriteResult{}, err
	}

	backupPath, err := s.createBackup(dir, target)
	if err != nil {
		return WriteResult{}, err
	}
	fire(writebackTestHookAfterBackup)

	newContent := []byte(content)
	// recheck runs again inside atomicWrite immediately before the rename.
	if err := atomicWrite(dir, target, newContent, recheck); err != nil {
		return WriteResult{}, err
	}
	return WriteResult{
		Status:      "written",
		TargetPath:  target.path,
		BackupPath:  backupPath,
		CurrentHash: hashBytes(newContent),
	}, nil
}
// readWriteTarget returns a snapshot of the TOML file for agent id. The agents
// directory descriptor used to read it is closed before returning.
func (s Store) readWriteTarget(id string) (writeTarget, error) {
	target, dir, err := s.openWriteTarget(id)
	if err == nil {
		dir.close()
		return target, nil
	}
	return writeTarget{}, err
}
// openWriteTarget validates id, opens the agents directory under s.CodexHome,
// and reads "<id>.toml" through that descriptor. It confirms the agents path
// still names the opened directory before returning.
//
// On success the caller owns the returned handle and must close it. On error
// the handle is the zero value and must not be closed.
func (s Store) openWriteTarget(id string) (writeTarget, agentsDirHandle, error) {
	if !safeAgentID.MatchString(id) {
		return writeTarget{}, agentsDirHandle{}, codexhome.ErrForbiddenPath
	}
	dir, err := openAgentsDir(s.CodexHome)
	if err != nil {
		return writeTarget{}, agentsDirHandle{}, err
	}

	target, err := s.readTargetFromDir(dir, id)
	if err == nil {
		// Re-check after the read so a directory swapped mid-read is caught.
		err = ensureAgentsPathStillMatches(dir)
	}
	if err != nil {
		dir.close()
		return writeTarget{}, agentsDirHandle{}, err
	}
	return target, dir, nil
}
// openAgentsDir opens <home>/agents as a directory descriptor. It does not
// follow a symlink at the final path component, records the directory's
// device/inode, and confirms the path still names that same directory.
//
// ELOOP (symlink) and ENOTDIR (not a directory) from the open are reported as
// codexhome.ErrForbiddenPath. Other errors are returned unchanged.
func openAgentsDir(home string) (agentsDirHandle, error) {
	agentsPath := filepath.Join(home, "agents")
	openFlags := unix.O_RDONLY | unix.O_DIRECTORY | unix.O_CLOEXEC | unix.O_NOFOLLOW

	fd, err := unix.Open(agentsPath, openFlags, 0)
	if err != nil {
		switch {
		case errors.Is(err, unix.ELOOP), errors.Is(err, unix.ENOTDIR):
			return agentsDirHandle{}, codexhome.ErrForbiddenPath
		default:
			return agentsDirHandle{}, err
		}
	}

	// fail releases the descriptor on every error path after a successful open.
	fail := func(cause error) (agentsDirHandle, error) {
		_ = unix.Close(fd)
		return agentsDirHandle{}, cause
	}

	var st unix.Stat_t
	if err := unix.Fstat(fd, &st); err != nil {
		return fail(err)
	}
	handle := agentsDirHandle{fd: fd, path: agentsPath, identity: identityOfUnix(st)}
	if err := ensureAgentsPathStillMatches(handle); err != nil {
		return fail(err)
	}
	return handle, nil
}
// close releases the directory descriptor and ignores any close error.
//
// NOTE(review): the zero value has fd 0, so calling close on a zero
// agentsDirHandle would close descriptor 0. The callers in this file only
// close handles that were returned together with a nil error.
func (d agentsDirHandle) close() {
	if d.fd < 0 {
		return
	}
	_ = unix.Close(d.fd)
}
// ensureAgentsPathStillMatches re-lstats dir.path and confirms it is still a
// real directory (not a symlink) with the same device/inode as the descriptor
// in dir.
//
// A missing path, a symlink, a non-directory, or a different identity is
// reported as ErrWriteConflict. Other lstat failures are returned unchanged.
func ensureAgentsPathStillMatches(dir agentsDirHandle) error {
	info, err := os.Lstat(dir.path)
	switch {
	case os.IsNotExist(err):
		return ErrWriteConflict
	case err != nil:
		return err
	}

	isSymlink := info.Mode()&os.ModeSymlink != 0
	if isSymlink || !info.IsDir() {
		return ErrWriteConflict
	}

	current, err := identityOf(info)
	if err != nil {
		return err
	}
	if current == dir.identity {
		return nil
	}
	return ErrWriteConflict
}
// readTargetFromDir opens "<id>.toml" relative to the open agents directory
// descriptor (no path re-resolution). It returns the file's bytes, permission
// bits, and device/inode identity together with the directory's identity.
//
// The name is also checked with codexhome.ResolveAgentTOML. A symlink (open
// fails with ELOOP under O_NOFOLLOW) or any non-regular file yields
// codexhome.ErrForbiddenPath. Other open errors are returned unwrapped, because
// verifyWriteTarget maps ENOENT through os.IsNotExist.
func (s Store) readTargetFromDir(dir agentsDirHandle, id string) (writeTarget, error) {
	if !safeAgentID.MatchString(id) {
		return writeTarget{}, codexhome.ErrForbiddenPath
	}
	base := id + ".toml"
	if _, err := codexhome.ResolveAgentTOML(s.CodexHome, base); err != nil {
		return writeTarget{}, err
	}

	// O_NONBLOCK: without it, open(2) on a FIFO blocks until a writer appears.
	// A FIFO placed at <id>.toml would then hang this call before the
	// regular-file check below could reject it. Inside WriteDraft that hang
	// would happen while writebackMu is held.
	fd, err := unix.Openat(dir.fd, base, unix.O_RDONLY|unix.O_CLOEXEC|unix.O_NOFOLLOW|unix.O_NONBLOCK, 0)
	if err != nil {
		if errors.Is(err, unix.ELOOP) {
			return writeTarget{}, codexhome.ErrForbiddenPath
		}
		return writeTarget{}, err
	}

	var stat unix.Stat_t
	if err := unix.Fstat(fd, &stat); err != nil {
		_ = unix.Close(fd)
		return writeTarget{}, err
	}
	if stat.Mode&unix.S_IFMT != unix.S_IFREG {
		_ = unix.Close(fd)
		return writeTarget{}, codexhome.ErrForbiddenPath
	}
	// The descriptor is now known to be a regular file. Restore ordinary
	// blocking mode before handing it to os.NewFile.
	if err := unix.SetNonblock(fd, false); err != nil {
		_ = unix.Close(fd)
		return writeTarget{}, err
	}

	// readAllFromFD takes ownership of fd and closes it.
	data, err := readAllFromFD(fd, base)
	if err != nil {
		return writeTarget{}, err
	}
	return writeTarget{
		path:           filepath.Join(dir.path, base),
		base:           base,
		content:        data,
		mode:           os.FileMode(stat.Mode & 0o777),
		agentsIdentity: dir.identity,
		targetIdentity: identityOfUnix(stat),
	}, nil
}
// readAllFromFD reads fd to EOF and always closes it: through the *os.File
// wrapper when one can be created, or directly otherwise. name is used only
// as the wrapper's display name.
func readAllFromFD(fd int, name string) ([]byte, error) {
	f := os.NewFile(uintptr(fd), name)
	if f != nil {
		defer f.Close()
		return io.ReadAll(f)
	}
	_ = unix.Close(fd)
	return nil, errors.New("无法打开目标文件")
}
// verifyWriteTarget re-reads the agent file through dir and confirms that
// nothing has moved since expected was captured. It checks the agents path
// identity, the file path and name, the directory and file device/inode, and
// the content hash against expectedHash.
//
// A vanished file or any mismatch is reported as ErrWriteConflict. On success
// the fresh snapshot is returned.
func (s Store) verifyWriteTarget(dir agentsDirHandle, id string, expected writeTarget, expectedHash string) (writeTarget, error) {
	if err := ensureAgentsPathStillMatches(dir); err != nil {
		return writeTarget{}, err
	}

	current, err := s.readTargetFromDir(dir, id)
	if err != nil {
		if os.IsNotExist(err) {
			return writeTarget{}, ErrWriteConflict
		}
		return writeTarget{}, err
	}

	sameFile := current.path == expected.path &&
		current.base == expected.base &&
		current.agentsIdentity == expected.agentsIdentity &&
		current.targetIdentity == expected.targetIdentity
	if !sameFile || hashBytes(current.content) != expectedHash {
		return writeTarget{}, ErrWriteConflict
	}
	return current, nil
}
// identityOf extracts the device/inode identity from info.Sys(), which must be
// a *unix.Stat_t or a *syscall.Stat_t. Any other Sys() value is an error.
func identityOf(info os.FileInfo) (fileIdentity, error) {
	sys := info.Sys()
	if st, ok := sys.(*unix.Stat_t); ok {
		return identityOfUnix(*st), nil
	}
	if st, ok := sys.(*syscall.Stat_t); ok {
		return fileIdentity{dev: uint64(st.Dev), ino: uint64(st.Ino)}, nil
	}
	return fileIdentity{}, errors.New("无法确认文件身份")
}
// identityOfUnix converts a unix.Stat_t into its device/inode identity.
func identityOfUnix(stat unix.Stat_t) fileIdentity {
	var id fileIdentity
	id.dev = uint64(stat.Dev)
	id.ino = uint64(stat.Ino)
	return id
}
// createBackup writes target.content to a new file named
// "<base>.bak-<UTC timestamp>" inside dir and returns its full path.
//
// The file is created exclusively and never through a symlink (O_EXCL |
// O_NOFOLLOW), with target.mode as the requested mode (subject to the umask).
// Its data is fsynced before returning, so the backup is on stable storage
// before the caller overwrites the original. A partially written backup is
// removed on any failure. The agents path identity is re-checked first.
func (s Store) createBackup(dir agentsDirHandle, target writeTarget) (string, error) {
	if err := ensureAgentsPathStillMatches(dir); err != nil {
		return "", err
	}
	backupName := fmt.Sprintf("%s.bak-%s", target.base, time.Now().UTC().Format("20060102T150405.000000000Z"))
	fd, err := unix.Openat(dir.fd, backupName, unix.O_WRONLY|unix.O_CREAT|unix.O_EXCL|unix.O_CLOEXEC|unix.O_NOFOLLOW, uint32(target.mode))
	if err != nil {
		return "", err
	}

	// discard removes the incomplete backup; callers close the fd first.
	discard := func(cause error) (string, error) {
		_ = unix.Unlinkat(dir.fd, backupName, 0)
		return "", cause
	}

	file := os.NewFile(uintptr(fd), backupName)
	if file == nil {
		_ = unix.Close(fd)
		return discard(errors.New("无法创建备份"))
	}
	if _, err := file.Write(target.content); err != nil {
		_ = file.Close()
		return discard(err)
	}
	// Without this, a crash after the original has been replaced could lose
	// the backed-up bytes that were still only in the page cache.
	if err := file.Sync(); err != nil {
		_ = file.Close()
		return discard(err)
	}
	if err := file.Close(); err != nil {
		return discard(err)
	}
	return filepath.Join(dir.path, backupName), nil
}
// atomicWrite replaces dir/target.base with content by renaming over it a
// hidden temporary file in the same directory.
//
// The steps are:
//  1. Create the temp file exclusively and without following symlinks.
//  2. Force target.mode with Fchmod, so the umask applied at creation does
//     not alter the final permissions.
//  3. Write the data and fsync it to disk.
//  4. Run beforeRename (if non-nil) as a last-moment check.
//  5. Rename the temp file over the target.
//
// If anything fails before the rename, the temp file is removed and the target
// is left untouched. After a successful rename the directory is fsynced on a
// best-effort basis.
func atomicWrite(dir agentsDirHandle, target writeTarget, content []byte, beforeRename func() error) error {
	tmpName := fmt.Sprintf(".%s.tmp-%d-%d", target.base, os.Getpid(), time.Now().UnixNano())
	fd, err := unix.Openat(dir.fd, tmpName, unix.O_WRONLY|unix.O_CREAT|unix.O_EXCL|unix.O_CLOEXEC|unix.O_NOFOLLOW, uint32(target.mode))
	if err != nil {
		return err
	}
	// Once renamed, tmpName no longer refers to our file, so only clean up
	// when the rename did not happen.
	renamed := false
	defer func() {
		if !renamed {
			_ = unix.Unlinkat(dir.fd, tmpName, 0)
		}
	}()

	if err := unix.Fchmod(fd, uint32(target.mode)); err != nil {
		_ = unix.Close(fd)
		return err
	}
	file := os.NewFile(uintptr(fd), tmpName)
	if file == nil {
		_ = unix.Close(fd)
		return errors.New("无法创建临时文件")
	}
	if _, err := file.Write(content); err != nil {
		_ = file.Close()
		return err
	}
	// Flush data before the rename publishes it. Otherwise a crash can leave
	// the target name pointing at an empty or partial file.
	if err := file.Sync(); err != nil {
		_ = file.Close()
		return err
	}
	if err := file.Close(); err != nil {
		return err
	}
	if beforeRename != nil {
		if err := beforeRename(); err != nil {
			return err
		}
	}
	if err := unix.Renameat(dir.fd, tmpName, dir.fd, target.base); err != nil {
		return err
	}
	renamed = true
	// Persist the directory entry change. This is best-effort: the new content
	// is already visible, so reporting a failure here would wrongly tell the
	// caller the write did not happen.
	_ = unix.Fsync(dir.fd)
	return nil
}
// hashBytes returns the lowercase hex-encoded SHA-256 digest of data.
func hashBytes(data []byte) string {
	digest := sha256.New()
	digest.Write(data) // hash.Hash.Write never returns an error
	return hex.EncodeToString(digest.Sum(nil))
}
// changedFields lists, in ascending key order, every key of before or after
// whose value differs between the two maps. A key missing from one map is
// compared as the empty string. The result is never nil.
func changedFields(before map[string]string, after map[string]string) []FieldChange {
	seen := make(map[string]struct{}, len(before)+len(after))
	names := make([]string, 0, len(before)+len(after))
	collect := func(m map[string]string) {
		for name := range m {
			if _, dup := seen[name]; dup {
				continue
			}
			seen[name] = struct{}{}
			names = append(names, name)
		}
	}
	collect(before)
	collect(after)
	sort.Strings(names)

	changes := make([]FieldChange, 0)
	for _, name := range names {
		was, now := before[name], after[name]
		if was != now {
			changes = append(changes, FieldChange{Field: name, Before: was, After: now})
		}
	}
	return changes
}
// simpleDiff renders a single-hunk, unified-style diff of after against
// before. Lines shared at the start and at the end are trimmed, and the
// remaining middle is emitted as "-" lines followed by "+" lines. Identical
// inputs yield "无差异\n".
func simpleDiff(before string, after string) string {
	if before == after {
		return "无差异\n"
	}
	oldLines, newLines := splitLines(before), splitLines(after)

	// Length of the shared leading run.
	head := 0
	for ; head < len(oldLines) && head < len(newLines); head++ {
		if oldLines[head] != newLines[head] {
			break
		}
	}
	// Length of the shared trailing run, never overlapping the leading run.
	tail := 0
	for ; tail < len(oldLines)-head && tail < len(newLines)-head; tail++ {
		if oldLines[len(oldLines)-1-tail] != newLines[len(newLines)-1-tail] {
			break
		}
	}

	var out strings.Builder
	out.WriteString("--- current\n+++ draft\n@@\n")
	emit := func(marker string, lines []string) {
		for _, line := range lines {
			out.WriteString(marker + line + "\n")
		}
	}
	emit("-", oldLines[head:len(oldLines)-tail])
	emit("+", newLines[head:len(newLines)-tail])
	return out.String()
}

// splitLines splits input on "\n" after dropping one trailing newline. Empty
// input (or a lone "\n") yields an empty, non-nil slice.
func splitLines(input string) []string {
	body := strings.TrimSuffix(input, "\n")
	if len(body) == 0 {
		return []string{}
	}
	return strings.Split(body, "\n")
}