fix: harden agent writeback safety

This commit is contained in:
Yoilun
2026-05-25 21:26:37 +08:00
parent a01dd36fb0
commit d7b75a1112
6 changed files with 341 additions and 34 deletions

View File

@@ -10,6 +10,8 @@ import (
"regexp"
"sort"
"strings"
"sync"
"syscall"
"time"
"codex-agent-manager/internal/codexhome"
@@ -19,29 +21,47 @@ var ErrWriteConflict = errors.New("目标文件已在校验后发生变化")
var safeAgentID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_-]*$`)
var writebackMu sync.Mutex
var writebackTestHookBeforeBackup func()
var writebackTestHookAfterBackup func()
type fileIdentity struct {
dev uint64
ino uint64
mode os.FileMode
}
type writeTarget struct {
path string
content []byte
mode os.FileMode
agentsIdentity fileIdentity
targetIdentity fileIdentity
}
func (s Store) ValidateDraft(id string, content string) (DraftValidation, error) {
targetPath, current, _, err := s.readWriteTarget(id)
target, err := s.readWriteTarget(id)
if err != nil {
return DraftValidation{}, err
}
result := DraftValidation{
TargetPath: targetPath,
CurrentHash: hashBytes(current),
TargetPath: target.path,
CurrentHash: hashBytes(target.content),
}
currentFields, currentErr := parseSimpleTOML(string(current))
currentFields, currentErr := parseSimpleTOML(string(target.content))
draftFields, draftErr := parseSimpleTOML(content)
if draftErr != nil {
result.Valid = false
result.Errors = []string{draftErr.Error()}
result.Diff = simpleDiff(string(current), content)
result.Diff = simpleDiff(string(target.content), content)
return result, nil
}
result.Valid = true
result.Errors = []string{}
result.Diff = simpleDiff(string(current), content)
result.Diff = simpleDiff(string(target.content), content)
if currentErr == nil {
result.FieldChanges = changedFields(currentFields, draftFields)
}
@@ -49,6 +69,9 @@ func (s Store) ValidateDraft(id string, content string) (DraftValidation, error)
}
func (s Store) WriteDraft(id string, content string, expectedHash string) (WriteResult, error) {
writebackMu.Lock()
defer writebackMu.Unlock()
validation, err := s.ValidateDraft(id, content)
if err != nil {
return WriteResult{}, err
@@ -60,58 +83,164 @@ func (s Store) WriteDraft(id string, content string, expectedHash string) (Write
return WriteResult{}, ErrWriteConflict
}
targetPath, current, mode, err := s.readWriteTarget(id)
target, err := s.readWriteTarget(id)
if err != nil {
return WriteResult{}, err
}
if hashBytes(current) != expectedHash {
if hashBytes(target.content) != expectedHash {
return WriteResult{}, ErrWriteConflict
}
backupPath, err := s.createBackup(targetPath, current, mode)
if writebackTestHookBeforeBackup != nil {
writebackTestHookBeforeBackup()
}
if _, err := s.verifyWriteTarget(id, target, expectedHash); err != nil {
return WriteResult{}, err
}
backupPath, err := s.createBackup(target.path, target.content, target.mode)
if err != nil {
return WriteResult{}, err
}
if err := atomicWrite(targetPath, []byte(content), mode); err != nil {
if writebackTestHookAfterBackup != nil {
writebackTestHookAfterBackup()
}
if err := atomicWrite(target, []byte(content), func() error {
_, err := s.verifyWriteTarget(id, target, expectedHash)
return err
}); err != nil {
return WriteResult{}, err
}
return WriteResult{
Status: "written",
TargetPath: targetPath,
TargetPath: target.path,
BackupPath: backupPath,
CurrentHash: hashBytes([]byte(content)),
}, nil
}
func (s Store) readWriteTarget(id string) (string, []byte, os.FileMode, error) {
func (s Store) readWriteTarget(id string) (writeTarget, error) {
if !safeAgentID.MatchString(id) {
return "", nil, 0, codexhome.ErrForbiddenPath
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsPath := filepath.Join(s.CodexHome, "agents")
if info, err := os.Lstat(agentsPath); err != nil {
return "", nil, 0, err
} else if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
return "", nil, 0, codexhome.ErrForbiddenPath
agentsInfo, err := os.Lstat(agentsPath)
if err != nil {
return writeTarget{}, err
} else if agentsInfo.Mode()&os.ModeSymlink != 0 || !agentsInfo.IsDir() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsIdentity, err := identityOf(agentsInfo)
if err != nil {
return writeTarget{}, err
}
fileName := id + ".toml"
targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, fileName)
if err != nil {
return "", nil, 0, err
return writeTarget{}, err
}
info, err := os.Lstat(targetPath)
targetInfo, err := os.Lstat(targetPath)
if err != nil {
return "", nil, 0, err
return writeTarget{}, err
}
if info.Mode()&os.ModeSymlink != 0 || !info.Mode().IsRegular() {
return "", nil, 0, codexhome.ErrForbiddenPath
if targetInfo.Mode()&os.ModeSymlink != 0 || !targetInfo.Mode().IsRegular() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
targetIdentity, err := identityOf(targetInfo)
if err != nil {
return writeTarget{}, err
}
data, err := os.ReadFile(targetPath)
if err != nil {
return "", nil, 0, err
return writeTarget{}, err
}
return targetPath, data, info.Mode().Perm(), nil
target := writeTarget{
path: targetPath,
content: data,
mode: targetInfo.Mode().Perm(),
agentsIdentity: agentsIdentity,
targetIdentity: targetIdentity,
}
if _, err := s.verifyWriteTarget(id, target, hashBytes(data)); err != nil {
return writeTarget{}, err
}
return target, nil
}
func (s Store) verifyWriteTarget(id string, expected writeTarget, expectedHash string) (writeTarget, error) {
current, err := s.readWriteTargetUnchecked(id)
if err != nil {
if os.IsNotExist(err) {
return writeTarget{}, ErrWriteConflict
}
return writeTarget{}, err
}
if current.path != expected.path ||
current.agentsIdentity != expected.agentsIdentity ||
current.targetIdentity != expected.targetIdentity {
return writeTarget{}, ErrWriteConflict
}
if hashBytes(current.content) != expectedHash {
return writeTarget{}, ErrWriteConflict
}
return current, nil
}
func (s Store) readWriteTargetUnchecked(id string) (writeTarget, error) {
if !safeAgentID.MatchString(id) {
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsPath := filepath.Join(s.CodexHome, "agents")
agentsInfo, err := os.Lstat(agentsPath)
if err != nil {
return writeTarget{}, err
}
if agentsInfo.Mode()&os.ModeSymlink != 0 || !agentsInfo.IsDir() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
agentsIdentity, err := identityOf(agentsInfo)
if err != nil {
return writeTarget{}, err
}
targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, id+".toml")
if err != nil {
return writeTarget{}, err
}
targetInfo, err := os.Lstat(targetPath)
if err != nil {
return writeTarget{}, err
}
if targetInfo.Mode()&os.ModeSymlink != 0 || !targetInfo.Mode().IsRegular() {
return writeTarget{}, codexhome.ErrForbiddenPath
}
targetIdentity, err := identityOf(targetInfo)
if err != nil {
return writeTarget{}, err
}
data, err := os.ReadFile(targetPath)
if err != nil {
return writeTarget{}, err
}
return writeTarget{
path: targetPath,
content: data,
mode: targetInfo.Mode().Perm(),
agentsIdentity: agentsIdentity,
targetIdentity: targetIdentity,
}, nil
}
func identityOf(info os.FileInfo) (fileIdentity, error) {
stat, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return fileIdentity{}, errors.New("无法确认文件身份")
}
return fileIdentity{
dev: uint64(stat.Dev),
ino: uint64(stat.Ino),
mode: info.Mode(),
}, nil
}
func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode) (string, error) {
@@ -132,9 +261,9 @@ func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode)
return backupPath, nil
}
func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
dir := filepath.Dir(targetPath)
base := filepath.Base(targetPath)
func atomicWrite(target writeTarget, content []byte, beforeRename func() error) error {
dir := filepath.Dir(target.path)
base := filepath.Base(target.path)
tmp, err := os.CreateTemp(dir, "."+base+".tmp-*")
if err != nil {
return err
@@ -144,7 +273,7 @@ func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
_ = os.Remove(tmpPath)
}()
if err := tmp.Chmod(mode); err != nil {
if err := tmp.Chmod(target.mode); err != nil {
_ = tmp.Close()
return err
}
@@ -155,7 +284,12 @@ func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
if err := tmp.Close(); err != nil {
return err
}
return os.Rename(tmpPath, targetPath)
if beforeRename != nil {
if err := beforeRename(); err != nil {
return err
}
}
return os.Rename(tmpPath, target.path)
}
func hashBytes(data []byte) string {

View File

@@ -6,6 +6,8 @@ import (
"path/filepath"
"strings"
"testing"
"codex-agent-manager/internal/codexhome"
)
func TestValidateInvalidTOMLReturnsInvalidAndDoesNotWrite(t *testing.T) {
@@ -78,6 +80,62 @@ func TestWriteExpectedHashMismatchRejectsAndLeavesOriginal(t *testing.T) {
assertFileContent(t, target, `name = "用户已改"`+"\n")
}
func TestWriteRejectsAgentsDirectoryReplacementBeforeBackup(t *testing.T) {
store, target := writebackFixture(t, `name = "旧名称"`+"\n")
validation, err := store.ValidateDraft("backend", `name = "新名称"`+"\n")
if err != nil {
t.Fatal(err)
}
root := filepath.Dir(filepath.Dir(target))
agentsDir := filepath.Join(root, "agents")
realAgentsDir := filepath.Join(root, "agents-real")
externalDir := filepath.Join(root, "external")
if err := os.MkdirAll(externalDir, 0o755); err != nil {
t.Fatal(err)
}
if err := os.WriteFile(filepath.Join(externalDir, "backend.toml"), []byte(`name = "外部"`+"\n"), 0o644); err != nil {
t.Fatal(err)
}
writebackTestHookBeforeBackup = func() {
if err := os.Rename(agentsDir, realAgentsDir); err != nil {
t.Fatal(err)
}
if err := os.Symlink(externalDir, agentsDir); err != nil {
t.Fatal(err)
}
}
defer func() { writebackTestHookBeforeBackup = nil }()
_, err = store.WriteDraft("backend", `name = "新名称"`+"\n", validation.CurrentHash)
if !errors.Is(err, ErrWriteConflict) && !errors.Is(err, codexhome.ErrForbiddenPath) {
t.Fatalf("expected directory replacement to be rejected, got %v", err)
}
assertFileContent(t, filepath.Join(realAgentsDir, "backend.toml"), `name = "旧名称"`+"\n")
assertFileContent(t, filepath.Join(externalDir, "backend.toml"), `name = "外部"`+"\n")
}
func TestWriteRejectsTargetChangeAfterBackup(t *testing.T) {
store, target := writebackFixture(t, `name = "旧名称"`+"\n")
validation, err := store.ValidateDraft("backend", `name = "新名称"`+"\n")
if err != nil {
t.Fatal(err)
}
writebackTestHookAfterBackup = func() {
if err := os.WriteFile(target, []byte(`name = "用户已改"`+"\n"), 0o644); err != nil {
t.Fatal(err)
}
}
defer func() { writebackTestHookAfterBackup = nil }()
_, err = store.WriteDraft("backend", `name = "新名称"`+"\n", validation.CurrentHash)
if !errors.Is(err, ErrWriteConflict) {
t.Fatalf("expected post-backup target change to be rejected, got %v", err)
}
assertFileContent(t, target, `name = "用户已改"`+"\n")
}
func TestWriteSuccessCreatesBackupAndAtomicallyWrites(t *testing.T) {
store, target := writebackFixture(t, `name = "旧名称"`+"\n")
validation, err := store.ValidateDraft("backend", `name = "新名称"`+"\n")