fix: harden agent writeback safety
This commit is contained in:
@@ -10,6 +10,8 @@ import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"codex-agent-manager/internal/codexhome"
|
||||
@@ -19,29 +21,47 @@ var ErrWriteConflict = errors.New("目标文件已在校验后发生变化")
|
||||
|
||||
var safeAgentID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_-]*$`)
|
||||
|
||||
var writebackMu sync.Mutex
|
||||
var writebackTestHookBeforeBackup func()
|
||||
var writebackTestHookAfterBackup func()
|
||||
|
||||
type fileIdentity struct {
|
||||
dev uint64
|
||||
ino uint64
|
||||
mode os.FileMode
|
||||
}
|
||||
|
||||
type writeTarget struct {
|
||||
path string
|
||||
content []byte
|
||||
mode os.FileMode
|
||||
agentsIdentity fileIdentity
|
||||
targetIdentity fileIdentity
|
||||
}
|
||||
|
||||
func (s Store) ValidateDraft(id string, content string) (DraftValidation, error) {
|
||||
targetPath, current, _, err := s.readWriteTarget(id)
|
||||
target, err := s.readWriteTarget(id)
|
||||
if err != nil {
|
||||
return DraftValidation{}, err
|
||||
}
|
||||
|
||||
result := DraftValidation{
|
||||
TargetPath: targetPath,
|
||||
CurrentHash: hashBytes(current),
|
||||
TargetPath: target.path,
|
||||
CurrentHash: hashBytes(target.content),
|
||||
}
|
||||
|
||||
currentFields, currentErr := parseSimpleTOML(string(current))
|
||||
currentFields, currentErr := parseSimpleTOML(string(target.content))
|
||||
draftFields, draftErr := parseSimpleTOML(content)
|
||||
if draftErr != nil {
|
||||
result.Valid = false
|
||||
result.Errors = []string{draftErr.Error()}
|
||||
result.Diff = simpleDiff(string(current), content)
|
||||
result.Diff = simpleDiff(string(target.content), content)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result.Valid = true
|
||||
result.Errors = []string{}
|
||||
result.Diff = simpleDiff(string(current), content)
|
||||
result.Diff = simpleDiff(string(target.content), content)
|
||||
if currentErr == nil {
|
||||
result.FieldChanges = changedFields(currentFields, draftFields)
|
||||
}
|
||||
@@ -49,6 +69,9 @@ func (s Store) ValidateDraft(id string, content string) (DraftValidation, error)
|
||||
}
|
||||
|
||||
func (s Store) WriteDraft(id string, content string, expectedHash string) (WriteResult, error) {
|
||||
writebackMu.Lock()
|
||||
defer writebackMu.Unlock()
|
||||
|
||||
validation, err := s.ValidateDraft(id, content)
|
||||
if err != nil {
|
||||
return WriteResult{}, err
|
||||
@@ -60,58 +83,164 @@ func (s Store) WriteDraft(id string, content string, expectedHash string) (Write
|
||||
return WriteResult{}, ErrWriteConflict
|
||||
}
|
||||
|
||||
targetPath, current, mode, err := s.readWriteTarget(id)
|
||||
target, err := s.readWriteTarget(id)
|
||||
if err != nil {
|
||||
return WriteResult{}, err
|
||||
}
|
||||
if hashBytes(current) != expectedHash {
|
||||
if hashBytes(target.content) != expectedHash {
|
||||
return WriteResult{}, ErrWriteConflict
|
||||
}
|
||||
|
||||
backupPath, err := s.createBackup(targetPath, current, mode)
|
||||
if writebackTestHookBeforeBackup != nil {
|
||||
writebackTestHookBeforeBackup()
|
||||
}
|
||||
if _, err := s.verifyWriteTarget(id, target, expectedHash); err != nil {
|
||||
return WriteResult{}, err
|
||||
}
|
||||
backupPath, err := s.createBackup(target.path, target.content, target.mode)
|
||||
if err != nil {
|
||||
return WriteResult{}, err
|
||||
}
|
||||
if err := atomicWrite(targetPath, []byte(content), mode); err != nil {
|
||||
if writebackTestHookAfterBackup != nil {
|
||||
writebackTestHookAfterBackup()
|
||||
}
|
||||
if err := atomicWrite(target, []byte(content), func() error {
|
||||
_, err := s.verifyWriteTarget(id, target, expectedHash)
|
||||
return err
|
||||
}); err != nil {
|
||||
return WriteResult{}, err
|
||||
}
|
||||
|
||||
return WriteResult{
|
||||
Status: "written",
|
||||
TargetPath: targetPath,
|
||||
TargetPath: target.path,
|
||||
BackupPath: backupPath,
|
||||
CurrentHash: hashBytes([]byte(content)),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s Store) readWriteTarget(id string) (string, []byte, os.FileMode, error) {
|
||||
func (s Store) readWriteTarget(id string) (writeTarget, error) {
|
||||
if !safeAgentID.MatchString(id) {
|
||||
return "", nil, 0, codexhome.ErrForbiddenPath
|
||||
return writeTarget{}, codexhome.ErrForbiddenPath
|
||||
}
|
||||
agentsPath := filepath.Join(s.CodexHome, "agents")
|
||||
if info, err := os.Lstat(agentsPath); err != nil {
|
||||
return "", nil, 0, err
|
||||
} else if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
|
||||
return "", nil, 0, codexhome.ErrForbiddenPath
|
||||
agentsInfo, err := os.Lstat(agentsPath)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
} else if agentsInfo.Mode()&os.ModeSymlink != 0 || !agentsInfo.IsDir() {
|
||||
return writeTarget{}, codexhome.ErrForbiddenPath
|
||||
}
|
||||
agentsIdentity, err := identityOf(agentsInfo)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
|
||||
fileName := id + ".toml"
|
||||
targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, fileName)
|
||||
if err != nil {
|
||||
return "", nil, 0, err
|
||||
return writeTarget{}, err
|
||||
}
|
||||
info, err := os.Lstat(targetPath)
|
||||
targetInfo, err := os.Lstat(targetPath)
|
||||
if err != nil {
|
||||
return "", nil, 0, err
|
||||
return writeTarget{}, err
|
||||
}
|
||||
if info.Mode()&os.ModeSymlink != 0 || !info.Mode().IsRegular() {
|
||||
return "", nil, 0, codexhome.ErrForbiddenPath
|
||||
if targetInfo.Mode()&os.ModeSymlink != 0 || !targetInfo.Mode().IsRegular() {
|
||||
return writeTarget{}, codexhome.ErrForbiddenPath
|
||||
}
|
||||
targetIdentity, err := identityOf(targetInfo)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
data, err := os.ReadFile(targetPath)
|
||||
if err != nil {
|
||||
return "", nil, 0, err
|
||||
return writeTarget{}, err
|
||||
}
|
||||
return targetPath, data, info.Mode().Perm(), nil
|
||||
target := writeTarget{
|
||||
path: targetPath,
|
||||
content: data,
|
||||
mode: targetInfo.Mode().Perm(),
|
||||
agentsIdentity: agentsIdentity,
|
||||
targetIdentity: targetIdentity,
|
||||
}
|
||||
if _, err := s.verifyWriteTarget(id, target, hashBytes(data)); err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
return target, nil
|
||||
}
|
||||
|
||||
func (s Store) verifyWriteTarget(id string, expected writeTarget, expectedHash string) (writeTarget, error) {
|
||||
current, err := s.readWriteTargetUnchecked(id)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return writeTarget{}, ErrWriteConflict
|
||||
}
|
||||
return writeTarget{}, err
|
||||
}
|
||||
if current.path != expected.path ||
|
||||
current.agentsIdentity != expected.agentsIdentity ||
|
||||
current.targetIdentity != expected.targetIdentity {
|
||||
return writeTarget{}, ErrWriteConflict
|
||||
}
|
||||
if hashBytes(current.content) != expectedHash {
|
||||
return writeTarget{}, ErrWriteConflict
|
||||
}
|
||||
return current, nil
|
||||
}
|
||||
|
||||
func (s Store) readWriteTargetUnchecked(id string) (writeTarget, error) {
|
||||
if !safeAgentID.MatchString(id) {
|
||||
return writeTarget{}, codexhome.ErrForbiddenPath
|
||||
}
|
||||
agentsPath := filepath.Join(s.CodexHome, "agents")
|
||||
agentsInfo, err := os.Lstat(agentsPath)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
if agentsInfo.Mode()&os.ModeSymlink != 0 || !agentsInfo.IsDir() {
|
||||
return writeTarget{}, codexhome.ErrForbiddenPath
|
||||
}
|
||||
agentsIdentity, err := identityOf(agentsInfo)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, id+".toml")
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
targetInfo, err := os.Lstat(targetPath)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
if targetInfo.Mode()&os.ModeSymlink != 0 || !targetInfo.Mode().IsRegular() {
|
||||
return writeTarget{}, codexhome.ErrForbiddenPath
|
||||
}
|
||||
targetIdentity, err := identityOf(targetInfo)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
data, err := os.ReadFile(targetPath)
|
||||
if err != nil {
|
||||
return writeTarget{}, err
|
||||
}
|
||||
return writeTarget{
|
||||
path: targetPath,
|
||||
content: data,
|
||||
mode: targetInfo.Mode().Perm(),
|
||||
agentsIdentity: agentsIdentity,
|
||||
targetIdentity: targetIdentity,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func identityOf(info os.FileInfo) (fileIdentity, error) {
|
||||
stat, ok := info.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
return fileIdentity{}, errors.New("无法确认文件身份")
|
||||
}
|
||||
return fileIdentity{
|
||||
dev: uint64(stat.Dev),
|
||||
ino: uint64(stat.Ino),
|
||||
mode: info.Mode(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode) (string, error) {
|
||||
@@ -132,9 +261,9 @@ func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode)
|
||||
return backupPath, nil
|
||||
}
|
||||
|
||||
func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
|
||||
dir := filepath.Dir(targetPath)
|
||||
base := filepath.Base(targetPath)
|
||||
func atomicWrite(target writeTarget, content []byte, beforeRename func() error) error {
|
||||
dir := filepath.Dir(target.path)
|
||||
base := filepath.Base(target.path)
|
||||
tmp, err := os.CreateTemp(dir, "."+base+".tmp-*")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -144,7 +273,7 @@ func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
|
||||
_ = os.Remove(tmpPath)
|
||||
}()
|
||||
|
||||
if err := tmp.Chmod(mode); err != nil {
|
||||
if err := tmp.Chmod(target.mode); err != nil {
|
||||
_ = tmp.Close()
|
||||
return err
|
||||
}
|
||||
@@ -155,7 +284,12 @@ func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmpPath, targetPath)
|
||||
if beforeRename != nil {
|
||||
if err := beforeRename(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return os.Rename(tmpPath, target.path)
|
||||
}
|
||||
|
||||
func hashBytes(data []byte) string {
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"codex-agent-manager/internal/codexhome"
|
||||
)
|
||||
|
||||
func TestValidateInvalidTOMLReturnsInvalidAndDoesNotWrite(t *testing.T) {
|
||||
@@ -78,6 +80,62 @@ func TestWriteExpectedHashMismatchRejectsAndLeavesOriginal(t *testing.T) {
|
||||
assertFileContent(t, target, `name = "用户已改"`+"\n")
|
||||
}
|
||||
|
||||
func TestWriteRejectsAgentsDirectoryReplacementBeforeBackup(t *testing.T) {
|
||||
store, target := writebackFixture(t, `name = "旧名称"`+"\n")
|
||||
validation, err := store.ValidateDraft("backend", `name = "新名称"`+"\n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
root := filepath.Dir(filepath.Dir(target))
|
||||
agentsDir := filepath.Join(root, "agents")
|
||||
realAgentsDir := filepath.Join(root, "agents-real")
|
||||
externalDir := filepath.Join(root, "external")
|
||||
if err := os.MkdirAll(externalDir, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(externalDir, "backend.toml"), []byte(`name = "外部"`+"\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
writebackTestHookBeforeBackup = func() {
|
||||
if err := os.Rename(agentsDir, realAgentsDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := os.Symlink(externalDir, agentsDir); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
defer func() { writebackTestHookBeforeBackup = nil }()
|
||||
|
||||
_, err = store.WriteDraft("backend", `name = "新名称"`+"\n", validation.CurrentHash)
|
||||
if !errors.Is(err, ErrWriteConflict) && !errors.Is(err, codexhome.ErrForbiddenPath) {
|
||||
t.Fatalf("expected directory replacement to be rejected, got %v", err)
|
||||
}
|
||||
assertFileContent(t, filepath.Join(realAgentsDir, "backend.toml"), `name = "旧名称"`+"\n")
|
||||
assertFileContent(t, filepath.Join(externalDir, "backend.toml"), `name = "外部"`+"\n")
|
||||
}
|
||||
|
||||
func TestWriteRejectsTargetChangeAfterBackup(t *testing.T) {
|
||||
store, target := writebackFixture(t, `name = "旧名称"`+"\n")
|
||||
validation, err := store.ValidateDraft("backend", `name = "新名称"`+"\n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
writebackTestHookAfterBackup = func() {
|
||||
if err := os.WriteFile(target, []byte(`name = "用户已改"`+"\n"), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
defer func() { writebackTestHookAfterBackup = nil }()
|
||||
|
||||
_, err = store.WriteDraft("backend", `name = "新名称"`+"\n", validation.CurrentHash)
|
||||
if !errors.Is(err, ErrWriteConflict) {
|
||||
t.Fatalf("expected post-backup target change to be rejected, got %v", err)
|
||||
}
|
||||
assertFileContent(t, target, `name = "用户已改"`+"\n")
|
||||
}
|
||||
|
||||
func TestWriteSuccessCreatesBackupAndAtomicallyWrites(t *testing.T) {
|
||||
store, target := writebackFixture(t, `name = "旧名称"`+"\n")
|
||||
validation, err := store.ValidateDraft("backend", `name = "新名称"`+"\n")
|
||||
|
||||
Reference in New Issue
Block a user