fix: harden agent writeback safety

This commit is contained in:
Yoilun
2026-05-25 21:26:37 +08:00
parent a01dd36fb0
commit d7b75a1112
6 changed files with 341 additions and 34 deletions

View File

@@ -10,6 +10,8 @@ import (
"regexp"
"sort"
"strings"
"sync"
"syscall"
"time"
"codex-agent-manager/internal/codexhome"
@@ -19,29 +21,47 @@ var ErrWriteConflict = errors.New("目标文件已在校验后发生变化")
var safeAgentID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_-]*$`)
var writebackMu sync.Mutex
var writebackTestHookBeforeBackup func()
var writebackTestHookAfterBackup func()
type fileIdentity struct {
dev uint64
ino uint64
mode os.FileMode
}
type writeTarget struct {
path string
content []byte
mode os.FileMode
agentsIdentity fileIdentity
targetIdentity fileIdentity
}
func (s Store) ValidateDraft(id string, content string) (DraftValidation, error) {
targetPath, current, _, err := s.readWriteTarget(id)
target, err := s.readWriteTarget(id)
if err != nil {
return DraftValidation{}, err
}
result := DraftValidation{
TargetPath: targetPath,
CurrentHash: hashBytes(current),
TargetPath: target.path,
CurrentHash: hashBytes(target.content),
}
currentFields, currentErr := parseSimpleTOML(string(current))
currentFields, currentErr := parseSimpleTOML(string(target.content))
draftFields, draftErr := parseSimpleTOML(content)
if draftErr != nil {
result.Valid = false
result.Errors = []string{draftErr.Error()}
result.Diff = simpleDiff(string(current), content)
result.Diff = simpleDiff(string(target.content), content)
return result, nil
}
result.Valid = true
result.Errors = []string{}
result.Diff = simpleDiff(string(current), content)
result.Diff = simpleDiff(string(target.content), content)
if currentErr == nil {
result.FieldChanges = changedFields(currentFields, draftFields)
}
@@ -49,6 +69,9 @@ func (s Store) ValidateDraft(id string, content string) (DraftValidation, error)
}
func (s Store) WriteDraft(id string, content string, expectedHash string) (WriteResult, error) {
writebackMu.Lock()
defer writebackMu.Unlock()
validation, err := s.ValidateDraft(id, content)
if err != nil {
return WriteResult{}, err
@@ -60,58 +83,164 @@ func (s Store) WriteDraft(id string, content string, expectedHash string) (Write
return WriteResult{}, ErrWriteConflict
}
targetPath, current, mode, err := s.readWriteTarget(id)
target, err := s.readWriteTarget(id)
if err != nil {
return WriteResult{}, err
}
if hashBytes(current) != expectedHash {
if hashBytes(target.content) != expectedHash {
return WriteResult{}, ErrWriteConflict
}
backupPath, err := s.createBackup(targetPath, current, mode)
if writebackTestHookBeforeBackup != nil {
writebackTestHookBeforeBackup()
}
if _, err := s.verifyWriteTarget(id, target, expectedHash); err != nil {
return WriteResult{}, err
}
backupPath, err := s.createBackup(target.path, target.content, target.mode)
if err != nil {
return WriteResult{}, err
}
if err := atomicWrite(targetPath, []byte(content), mode); err != nil {
if writebackTestHookAfterBackup != nil {
writebackTestHookAfterBackup()
}
if err := atomicWrite(target, []byte(content), func() error {
_, err := s.verifyWriteTarget(id, target, expectedHash)
return err
}); err != nil {
return WriteResult{}, err
}
return WriteResult{
Status: "written",
TargetPath: targetPath,
TargetPath: target.path,
BackupPath: backupPath,
CurrentHash: hashBytes([]byte(content)),
}, nil
}
func (s Store) readWriteTarget(id string) (string, []byte, os.FileMode, error) {
func (s Store) readWriteTarget(id string) (writeTarget, error) {
if !safeAgentID.MatchString(id) {
return "", nil, 0, codexhome.ErrForbiddenPath
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsPath := filepath.Join(s.CodexHome, "agents")
if info, err := os.Lstat(agentsPath); err != nil {
return "", nil, 0, err
} else if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
return "", nil, 0, codexhome.ErrForbiddenPath
agentsInfo, err := os.Lstat(agentsPath)
if err != nil {
return writeTarget{}, err
} else if agentsInfo.Mode()&os.ModeSymlink != 0 || !agentsInfo.IsDir() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsIdentity, err := identityOf(agentsInfo)
if err != nil {
return writeTarget{}, err
}
fileName := id + ".toml"
targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, fileName)
if err != nil {
return "", nil, 0, err
return writeTarget{}, err
}
info, err := os.Lstat(targetPath)
targetInfo, err := os.Lstat(targetPath)
if err != nil {
return "", nil, 0, err
return writeTarget{}, err
}
if info.Mode()&os.ModeSymlink != 0 || !info.Mode().IsRegular() {
return "", nil, 0, codexhome.ErrForbiddenPath
if targetInfo.Mode()&os.ModeSymlink != 0 || !targetInfo.Mode().IsRegular() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
targetIdentity, err := identityOf(targetInfo)
if err != nil {
return writeTarget{}, err
}
data, err := os.ReadFile(targetPath)
if err != nil {
return "", nil, 0, err
return writeTarget{}, err
}
return targetPath, data, info.Mode().Perm(), nil
target := writeTarget{
path: targetPath,
content: data,
mode: targetInfo.Mode().Perm(),
agentsIdentity: agentsIdentity,
targetIdentity: targetIdentity,
}
if _, err := s.verifyWriteTarget(id, target, hashBytes(data)); err != nil {
return writeTarget{}, err
}
return target, nil
}
func (s Store) verifyWriteTarget(id string, expected writeTarget, expectedHash string) (writeTarget, error) {
current, err := s.readWriteTargetUnchecked(id)
if err != nil {
if os.IsNotExist(err) {
return writeTarget{}, ErrWriteConflict
}
return writeTarget{}, err
}
if current.path != expected.path ||
current.agentsIdentity != expected.agentsIdentity ||
current.targetIdentity != expected.targetIdentity {
return writeTarget{}, ErrWriteConflict
}
if hashBytes(current.content) != expectedHash {
return writeTarget{}, ErrWriteConflict
}
return current, nil
}
func (s Store) readWriteTargetUnchecked(id string) (writeTarget, error) {
if !safeAgentID.MatchString(id) {
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsPath := filepath.Join(s.CodexHome, "agents")
agentsInfo, err := os.Lstat(agentsPath)
if err != nil {
return writeTarget{}, err
}
if agentsInfo.Mode()&os.ModeSymlink != 0 || !agentsInfo.IsDir() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsIdentity, err := identityOf(agentsInfo)
if err != nil {
return writeTarget{}, err
}
targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, id+".toml")
if err != nil {
return writeTarget{}, err
}
targetInfo, err := os.Lstat(targetPath)
if err != nil {
return writeTarget{}, err
}
if targetInfo.Mode()&os.ModeSymlink != 0 || !targetInfo.Mode().IsRegular() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
targetIdentity, err := identityOf(targetInfo)
if err != nil {
return writeTarget{}, err
}
data, err := os.ReadFile(targetPath)
if err != nil {
return writeTarget{}, err
}
return writeTarget{
path: targetPath,
content: data,
mode: targetInfo.Mode().Perm(),
agentsIdentity: agentsIdentity,
targetIdentity: targetIdentity,
}, nil
}
func identityOf(info os.FileInfo) (fileIdentity, error) {
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return fileIdentity{}, errors.New("无法确认文件身份")
}
return fileIdentity{
dev: uint64(stat.Dev),
ino: uint64(stat.Ino),
mode: info.Mode(),
}, nil
}
func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode) (string, error) {
@@ -132,9 +261,9 @@ func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode)
return backupPath, nil
}
func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
dir := filepath.Dir(targetPath)
base := filepath.Base(targetPath)
func atomicWrite(target writeTarget, content []byte, beforeRename func() error) error {
dir := filepath.Dir(target.path)
base := filepath.Base(target.path)
tmp, err := os.CreateTemp(dir, "."+base+".tmp-*")
if err != nil {
return err
@@ -144,7 +273,7 @@ func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
_ = os.Remove(tmpPath)
}()
if err := tmp.Chmod(mode); err != nil {
if err := tmp.Chmod(target.mode); err != nil {
_ = tmp.Close()
return err
}
@@ -155,7 +284,12 @@ func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
if err := tmp.Close(); err != nil {
return err
}
return os.Rename(tmpPath, targetPath)
if beforeRename != nil {
if err := beforeRename(); err != nil {
return err
}
}
return os.Rename(tmpPath, target.path)
}
func hashBytes(data []byte) string {