Files
codex-agent-manager/internal/agents/writeback.go
2026-05-25 21:26:37 +08:00

363 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package agents
import (
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"sync"
	"syscall"
	"time"

	"codex-agent-manager/internal/codexhome"
)
var (
	// ErrWriteConflict is returned when the agent file no longer matches what
	// the caller validated: its hash differs from the expected one, it has been
	// replaced by a different file or directory, or it has disappeared.
	ErrWriteConflict = errors.New("目标文件已在校验后发生变化")

	// safeAgentID accepts ids that start with an ASCII letter or digit followed
	// by letters, digits, underscores or hyphens. Because ids are turned into
	// "<id>.toml" file names, this keeps path separators and dots out of them.
	safeAgentID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_-]*$`)

	// writebackMu serializes WriteDraft calls within this process.
	writebackMu sync.Mutex

	// Optional hooks that WriteDraft invokes, when non-nil, just before it
	// re-verifies the target and creates the backup, and right after the backup
	// has been written. As the names indicate, they let tests inject changes at
	// those points.
	writebackTestHookBeforeBackup func()
	writebackTestHookAfterBackup  func()
)
// fileIdentity records what is needed to tell whether a path still names the
// same file: the device and inode numbers from the platform stat data and the
// full mode (type and permission bits). Values are compared with == and !=.
type fileIdentity struct {
	dev, ino uint64      // device and inode numbers
	mode     os.FileMode // file type and permission bits
}
// writeTarget is a snapshot of an agent TOML file taken at read time. It holds
// the resolved path, the contents and permission bits, and the identities of
// the agents directory and of the file itself at that moment.
type writeTarget struct {
	path    string      // resolved path of the agent TOML file
	content []byte      // file contents as read
	mode    os.FileMode // permission bits only (Mode().Perm())

	// Identities used to detect that the directory or the file was replaced.
	agentsIdentity, targetIdentity fileIdentity
}
// ValidateDraft compares a proposed replacement for agent id's TOML file with
// the file's current contents.
//
// The result always carries the resolved target path, the hex SHA-256 of the
// current contents and a line diff from current to draft. If content fails to
// parse with parseSimpleTOML, that is reported in the result (Valid false,
// Errors holding the parse message), not as an error. Otherwise Valid is true,
// Errors is an empty slice, and FieldChanges lists the differing fields, but
// only when the current file also parses. The returned error is reserved for
// failures to locate or read the target (see readWriteTarget).
func (s Store) ValidateDraft(id string, content string) (DraftValidation, error) {
	target, err := s.readWriteTarget(id)
	if err != nil {
		return DraftValidation{}, err
	}
	current := string(target.content)
	currentFields, currentErr := parseSimpleTOML(current)
	draftFields, draftErr := parseSimpleTOML(content)

	validation := DraftValidation{
		TargetPath:  target.path,
		CurrentHash: hashBytes(target.content),
		Diff:        simpleDiff(current, content),
	}
	if draftErr != nil {
		validation.Valid = false
		validation.Errors = []string{draftErr.Error()}
		return validation, nil
	}

	validation.Valid = true
	validation.Errors = []string{}
	if currentErr == nil {
		validation.FieldChanges = changedFields(currentFields, draftFields)
	}
	return validation, nil
}
// WriteDraft replaces agent id's TOML file with content, guarded by an
// optimistic-concurrency check.
//
// The draft must pass ValidateDraft, and expectedHash must be non-empty and
// equal to the hex SHA-256 of the file's current contents; otherwise
// ErrWriteConflict is returned. The target is re-read and re-verified before a
// timestamped backup of the current contents is created, and verified once more
// just before the atomic rename. Calls are serialized by writebackMu within this
// process.
//
// NOTE: a backup that has already been created is left in place if the final
// write fails.
func (s Store) WriteDraft(id string, content string, expectedHash string) (WriteResult, error) {
	writebackMu.Lock()
	defer writebackMu.Unlock()

	validation, err := s.ValidateDraft(id, content)
	switch {
	case err != nil:
		return WriteResult{}, err
	case !validation.Valid:
		return WriteResult{}, errors.New(strings.Join(validation.Errors, ""))
	case expectedHash == "" || validation.CurrentHash != expectedHash:
		return WriteResult{}, ErrWriteConflict
	}

	target, err := s.readWriteTarget(id)
	if err != nil {
		return WriteResult{}, err
	}
	if hashBytes(target.content) != expectedHash {
		return WriteResult{}, ErrWriteConflict
	}

	// stillUnchanged fails unless the target keeps the same path, identities
	// and expected hash as the snapshot taken above.
	stillUnchanged := func() error {
		_, verifyErr := s.verifyWriteTarget(id, target, expectedHash)
		return verifyErr
	}

	if hook := writebackTestHookBeforeBackup; hook != nil {
		hook()
	}
	if err := stillUnchanged(); err != nil {
		return WriteResult{}, err
	}
	backupPath, err := s.createBackup(target.path, target.content, target.mode)
	if err != nil {
		return WriteResult{}, err
	}
	if hook := writebackTestHookAfterBackup; hook != nil {
		hook()
	}

	replacement := []byte(content)
	if err := atomicWrite(target, replacement, stillUnchanged); err != nil {
		return WriteResult{}, err
	}
	return WriteResult{
		Status:      "written",
		TargetPath:  target.path,
		BackupPath:  backupPath,
		CurrentHash: hashBytes(replacement),
	}, nil
}
// readWriteTarget resolves and reads agent id's TOML file, then immediately
// re-reads it through verifyWriteTarget.
//
// All path and file-type checks are those of readWriteTargetUnchecked; its
// errors (including a plain not-exist error on the first read) are returned
// unchanged. If the second read shows a different path, different directory or
// file identity, or different contents, ErrWriteConflict is returned. On
// success the snapshot from the first read is returned.
func (s Store) readWriteTarget(id string) (writeTarget, error) {
	target, err := s.readWriteTargetUnchecked(id)
	if err != nil {
		return writeTarget{}, err
	}
	if _, err := s.verifyWriteTarget(id, target, hashBytes(target.content)); err != nil {
		return writeTarget{}, err
	}
	return target, nil
}
// verifyWriteTarget re-reads agent id's file and checks it against the
// snapshot expected and expectedHash.
//
// It returns ErrWriteConflict when the file (or a path component) no longer
// exists, when the resolved path or the identity of the agents directory or of
// the file differs from expected, or when the contents no longer hash to
// expectedHash. Other read errors are returned unchanged. On success the fresh
// snapshot is returned.
func (s Store) verifyWriteTarget(id string, expected writeTarget, expectedHash string) (writeTarget, error) {
	current, err := s.readWriteTargetUnchecked(id)
	if err != nil {
		// errors.Is (unlike os.IsNotExist) also matches not-exist errors that a
		// helper wrapped with %w, e.g. possibly codexhome.ResolveAgentTOML.
		if errors.Is(err, os.ErrNotExist) {
			return writeTarget{}, ErrWriteConflict
		}
		return writeTarget{}, err
	}
	sameFile := current.path == expected.path &&
		current.agentsIdentity == expected.agentsIdentity &&
		current.targetIdentity == expected.targetIdentity
	if !sameFile || hashBytes(current.content) != expectedHash {
		return writeTarget{}, ErrWriteConflict
	}
	return current, nil
}
// readWriteTargetUnchecked resolves agent id's TOML file and reads it once,
// without comparing the result against any earlier read.
//
// It returns codexhome.ErrForbiddenPath when id does not match safeAgentID,
// when <CodexHome>/agents is a symlink or not a directory, or when the target
// is a symlink or not a regular file. Errors from os.Lstat are returned as-is,
// so a missing directory or file surfaces as a not-exist error. The contents
// are read from the very file that was Lstat'ed (see readTargetNoFollow).
func (s Store) readWriteTargetUnchecked(id string) (writeTarget, error) {
	if !safeAgentID.MatchString(id) {
		return writeTarget{}, codexhome.ErrForbiddenPath
	}
	agentsPath := filepath.Join(s.CodexHome, "agents")
	agentsInfo, err := os.Lstat(agentsPath)
	if err != nil {
		return writeTarget{}, err
	}
	if agentsInfo.Mode()&os.ModeSymlink != 0 || !agentsInfo.IsDir() {
		return writeTarget{}, codexhome.ErrForbiddenPath
	}
	agentsIdentity, err := identityOf(agentsInfo)
	if err != nil {
		return writeTarget{}, err
	}
	targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, id+".toml")
	if err != nil {
		return writeTarget{}, err
	}
	targetInfo, err := os.Lstat(targetPath)
	if err != nil {
		return writeTarget{}, err
	}
	if targetInfo.Mode()&os.ModeSymlink != 0 || !targetInfo.Mode().IsRegular() {
		return writeTarget{}, codexhome.ErrForbiddenPath
	}
	targetIdentity, err := identityOf(targetInfo)
	if err != nil {
		return writeTarget{}, err
	}
	data, err := readTargetNoFollow(targetPath, targetIdentity)
	if err != nil {
		return writeTarget{}, err
	}
	return writeTarget{
		path:           targetPath,
		content:        data,
		mode:           targetInfo.Mode().Perm(),
		agentsIdentity: agentsIdentity,
		targetIdentity: targetIdentity,
	}, nil
}

// readTargetNoFollow reads path through a descriptor opened with O_NOFOLLOW and
// checks, via fstat on that descriptor, that it is the same file (device, inode
// and mode) as want. os.ReadFile follows symlinks, so without this the file
// could be swapped for a symlink or another file between the caller's Lstat
// and the read, and that other file's contents would be returned.
func readTargetNoFollow(path string, want fileIdentity) ([]byte, error) {
	file, err := os.OpenFile(path, os.O_RDONLY|syscall.O_NOFOLLOW, 0)
	if err != nil {
		// On Linux and macOS, O_NOFOLLOW makes open fail with ELOOP when the
		// final path component is a symlink.
		if errors.Is(err, syscall.ELOOP) {
			return nil, codexhome.ErrForbiddenPath
		}
		return nil, err
	}
	// Read-only descriptor: a Close error cannot lose data, so it is ignored.
	defer file.Close()
	info, err := file.Stat()
	if err != nil {
		return nil, err
	}
	got, err := identityOf(info)
	if err != nil {
		return nil, err
	}
	if got != want {
		// The path now names a different file than the one just checked.
		return nil, ErrWriteConflict
	}
	return io.ReadAll(file)
}
// identityOf extracts the device number, inode number and mode from info.
// It requires info.Sys() to be a *syscall.Stat_t; for any other platform data
// it returns an error.
func identityOf(info os.FileInfo) (fileIdentity, error) {
	if stat, ok := info.Sys().(*syscall.Stat_t); ok {
		return fileIdentity{
			dev:  uint64(stat.Dev),
			ino:  uint64(stat.Ino),
			mode: info.Mode(),
		}, nil
	}
	return fileIdentity{}, errors.New("无法确认文件身份")
}
// createBackup writes content to a new file next to targetPath named
// "<targetPath>.bak-<UTC timestamp with nanoseconds>" and returns its path.
//
// The file is created exclusively (O_EXCL), so an existing file is never
// overwritten, with permission bits mode (before umask). Its contents are
// flushed to stable storage before returning: WriteDraft replaces the original
// right after this call, and the backup is then the only copy of the previous
// contents. On any failure the incomplete backup is removed.
func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode) (string, error) {
	stamp := time.Now().UTC().Format("20060102T150405.000000000Z")
	backupPath := fmt.Sprintf("%s.bak-%s", targetPath, stamp)
	file, err := os.OpenFile(backupPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, mode)
	if err != nil {
		return "", err
	}
	// abandon closes and deletes the incomplete backup, then reports cause.
	abandon := func(cause error) (string, error) {
		_ = file.Close() // cause is the error worth reporting
		_ = os.Remove(backupPath)
		return "", cause
	}
	if _, err := file.Write(content); err != nil {
		return abandon(err)
	}
	if err := file.Sync(); err != nil {
		return abandon(err)
	}
	if err := file.Close(); err != nil {
		_ = os.Remove(backupPath)
		return "", err
	}
	return backupPath, nil
}
// atomicWrite replaces target.path with content. It writes a hidden temporary
// file in the same directory and renames it over the target, so the target
// never holds a partially written file.
//
// The temporary file gets permission bits target.mode and is flushed to stable
// storage before the rename. beforeRename, when non-nil, runs after that flush
// and immediately before the rename; a non-nil error from it aborts the write
// and leaves the target untouched. The temporary file is removed on every
// failure path. After a successful rename the directory is synced on a
// best-effort basis.
func atomicWrite(target writeTarget, content []byte, beforeRename func() error) error {
	dir := filepath.Dir(target.path)
	base := filepath.Base(target.path)
	tmp, err := os.CreateTemp(dir, "."+base+".tmp-*")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	renamed := false
	defer func() {
		// After the rename tmpPath no longer names our file; clean up only on failure.
		if !renamed {
			_ = os.Remove(tmpPath)
		}
	}()
	if err := tmp.Chmod(target.mode); err != nil {
		_ = tmp.Close() // the Chmod error is the one to report
		return err
	}
	if _, err := tmp.Write(content); err != nil {
		_ = tmp.Close() // the Write error is the one to report
		return err
	}
	// Flush the data before the rename publishes it; otherwise a crash shortly
	// after the rename can leave the target empty or truncated.
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close() // the Sync error is the one to report
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	if beforeRename != nil {
		if err := beforeRename(); err != nil {
			return err
		}
	}
	if err := os.Rename(tmpPath, target.path); err != nil {
		return err
	}
	renamed = true
	syncDirBestEffort(dir)
	return nil
}

// syncDirBestEffort fsyncs dir so that a rename inside it survives a crash.
// Errors are deliberately ignored: the rename has already taken effect, and
// some platforms or filesystems reject fsync on a directory, so a failure here
// must not be reported to the caller as a failed write.
func syncDirBestEffort(dir string) {
	d, err := os.Open(dir)
	if err != nil {
		return
	}
	_ = d.Sync()
	_ = d.Close()
}
// hashBytes returns the SHA-256 digest of data as a lowercase hex string.
func hashBytes(data []byte) string {
	hasher := sha256.New()
	_, _ = hasher.Write(data) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
// changedFields lists, sorted by field name, every field whose value differs
// between before and after. A field missing from one map compares as the empty
// string, so "absent" and "" are treated as equal. The result is never nil; an
// empty slice means no differences.
func changedFields(before map[string]string, after map[string]string) []FieldChange {
	seen := make(map[string]struct{}, len(before)+len(after))
	for _, fields := range []map[string]string{before, after} {
		for name := range fields {
			seen[name] = struct{}{}
		}
	}
	names := make([]string, 0, len(seen))
	for name := range seen {
		names = append(names, name)
	}
	sort.Strings(names)

	changes := make([]FieldChange, 0)
	for _, name := range names {
		oldValue, newValue := before[name], after[name]
		if oldValue != newValue {
			changes = append(changes, FieldChange{Field: name, Before: oldValue, After: newValue})
		}
	}
	return changes
}
// simpleDiff renders a minimal single-hunk diff from before to after.
//
// Lines shared at the start and at the end of both inputs are dropped, and
// whatever remains in the middle is emitted after a fixed "--- current",
// "+++ draft", "@@" header: first the removed lines prefixed with "-", then the
// added lines prefixed with "+". Identical inputs yield the fixed marker
// "无差异\n" instead.
func simpleDiff(before string, after string) string {
	if before == after {
		return "无差异\n"
	}
	oldLines, newLines := splitLines(before), splitLines(after)

	shorter := len(oldLines)
	if len(newLines) < shorter {
		shorter = len(newLines)
	}
	// Length of the common leading run.
	head := 0
	for head < shorter && oldLines[head] == newLines[head] {
		head++
	}
	// Length of the common trailing run, never overlapping the leading run.
	tail := 0
	for tail < shorter-head && oldLines[len(oldLines)-1-tail] == newLines[len(newLines)-1-tail] {
		tail++
	}

	var out strings.Builder
	out.WriteString("--- current\n+++ draft\n@@\n")
	emit := func(marker string, lines []string) {
		for _, line := range lines {
			out.WriteString(marker + line + "\n")
		}
	}
	emit("-", oldLines[head:len(oldLines)-tail])
	emit("+", newLines[head:len(newLines)-tail])
	return out.String()
}

// splitLines splits input on "\n" after dropping at most one trailing newline.
// An input that is empty, or just "\n", yields an empty (non-nil) slice.
func splitLines(input string) []string {
	body := strings.TrimSuffix(input, "\n")
	if len(body) == 0 {
		return []string{}
	}
	return strings.Split(body, "\n")
}