Files
codex-agent-manager/internal/agents/writeback.go
2026-05-25 21:06:32 +08:00

229 lines
5.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package agents
import (
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"time"

	"codex-agent-manager/internal/codexhome"
)
var (
	// ErrWriteConflict is returned by WriteDraft when the target file's current
	// hash does not match the hash the caller expected (including when the
	// caller supplied no hash at all).
	ErrWriteConflict = errors.New("目标文件已在校验后发生变化")

	// safeAgentID restricts agent ids to a leading ASCII letter or digit
	// followed by letters, digits, '_' or '-', so an id can never contain a
	// path separator or a "." component.
	safeAgentID = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9_-]*$`)
)
// ValidateDraft checks whether content parses with parseSimpleTOML and
// compares it with the agent file currently stored for id.
//
// The returned DraftValidation always carries the resolved target path, the
// SHA-256 hash of the current file and a line diff from the current file to
// content. If content fails to parse, Valid is false and Errors holds the
// parse message. Otherwise Valid is true, Errors is empty and, when the
// current file also parses, FieldChanges lists the differing fields.
//
// The error return is used only when the target cannot be located or read
// (see readWriteTarget); an unparsable draft is not an error.
func (s Store) ValidateDraft(id string, content string) (DraftValidation, error) {
	targetPath, current, _, err := s.readWriteTarget(id)
	if err != nil {
		return DraftValidation{}, err
	}

	currentText := string(current)
	currentFields, currentErr := parseSimpleTOML(currentText)
	draftFields, draftErr := parseSimpleTOML(content)

	// Valid starts at its zero value (false) and is only set once the
	// draft is known to parse.
	result := DraftValidation{
		TargetPath:  targetPath,
		CurrentHash: hashBytes(current),
		Diff:        simpleDiff(currentText, content),
	}
	if draftErr != nil {
		result.Errors = []string{draftErr.Error()}
		return result, nil
	}

	result.Valid = true
	result.Errors = []string{}
	// Field-level changes need both sides parsed; an unparsable current file
	// just leaves FieldChanges unset.
	if currentErr == nil {
		result.FieldChanges = changedFields(currentFields, draftFields)
	}
	return result, nil
}
// WriteDraft replaces the agent file for id with content. The write happens
// only if content passes ValidateDraft and the file still hashes to
// expectedHash.
//
// An invalid draft is reported as an error built from the validation
// messages. An empty expectedHash, or a hash that differs from the file's
// current SHA-256, yields ErrWriteConflict. The hash is compared twice: once
// against the read done by ValidateDraft and once against a fresh read just
// before writing. This narrows, but does not eliminate, the window in which a
// concurrent writer could be overwritten.
//
// Before replacing the file, its previous contents are saved with
// createBackup. The new content is then written via atomicWrite with the
// original permission bits. The result reports the backup path and the hash
// of the new content.
func (s Store) WriteDraft(id string, content string, expectedHash string) (WriteResult, error) {
	validation, err := s.ValidateDraft(id, content)
	if err != nil {
		return WriteResult{}, err
	}
	if !validation.Valid {
		// Separate the messages so several validation errors do not run together.
		return WriteResult{}, errors.New(strings.Join(validation.Errors, "; "))
	}
	if expectedHash == "" || validation.CurrentHash != expectedHash {
		return WriteResult{}, ErrWriteConflict
	}

	// Re-read right before writing so a change made after validation is
	// still detected.
	targetPath, current, mode, err := s.readWriteTarget(id)
	if err != nil {
		return WriteResult{}, err
	}
	if hashBytes(current) != expectedHash {
		return WriteResult{}, ErrWriteConflict
	}

	backupPath, err := s.createBackup(targetPath, current, mode)
	if err != nil {
		return WriteResult{}, err
	}
	if err := atomicWrite(targetPath, []byte(content), mode); err != nil {
		return WriteResult{}, err
	}
	return WriteResult{
		Status:      "written",
		TargetPath:  targetPath,
		BackupPath:  backupPath,
		CurrentHash: hashBytes([]byte(content)),
	}, nil
}
// readWriteTarget resolves and reads the agent definition file for id
// (id + ".toml" under CodexHome's "agents" directory).
//
// It returns codexhome.ErrForbiddenPath in these cases:
//   - id does not match safeAgentID;
//   - the agents directory is a symlink or not a directory;
//   - the target is not a regular file (symlinks included);
//   - the file actually opened is not the one inspected with Lstat.
//
// Other filesystem errors, such as a missing file, are returned unchanged.
// On success it returns the resolved path, the file contents and the file's
// permission bits.
func (s Store) readWriteTarget(id string) (string, []byte, os.FileMode, error) {
	if !safeAgentID.MatchString(id) {
		return "", nil, 0, codexhome.ErrForbiddenPath
	}
	agentsPath := filepath.Join(s.CodexHome, "agents")
	dirInfo, err := os.Lstat(agentsPath)
	if err != nil {
		return "", nil, 0, err
	}
	if dirInfo.Mode()&os.ModeSymlink != 0 || !dirInfo.IsDir() {
		return "", nil, 0, codexhome.ErrForbiddenPath
	}

	// Final path resolution is delegated to codexhome.ResolveAgentTOML.
	targetPath, err := codexhome.ResolveAgentTOML(s.CodexHome, id+".toml")
	if err != nil {
		return "", nil, 0, err
	}
	info, err := os.Lstat(targetPath)
	if err != nil {
		return "", nil, 0, err
	}
	if info.Mode()&os.ModeSymlink != 0 || !info.Mode().IsRegular() {
		return "", nil, 0, codexhome.ErrForbiddenPath
	}

	// os.Open follows symlinks, so the path could have been swapped for a
	// link after the Lstat above. Compare the opened file with the inspected
	// one so the returned bytes really come from that regular file.
	file, err := os.Open(targetPath)
	if err != nil {
		return "", nil, 0, err
	}
	defer file.Close() // read-only handle: a close error cannot lose data
	opened, err := file.Stat()
	if err != nil {
		return "", nil, 0, err
	}
	if !os.SameFile(info, opened) {
		return "", nil, 0, codexhome.ErrForbiddenPath
	}
	data, err := io.ReadAll(file)
	if err != nil {
		return "", nil, 0, err
	}
	return targetPath, data, info.Mode().Perm(), nil
}
// createBackup writes content to a new file next to targetPath and returns
// the new file's path. The name is targetPath + ".bak-" + a UTC timestamp
// with nanoseconds.
//
// The file is created with O_EXCL, so an existing file at that path is
// never overwritten. It is requested with the given permission bits (subject
// to the process umask). Data is flushed with Sync before returning, because
// the caller overwrites the target right afterwards and the backup must
// survive a crash. On any failure the partially written backup is removed.
func (s Store) createBackup(targetPath string, content []byte, mode os.FileMode) (string, error) {
	backupPath := fmt.Sprintf("%s.bak-%s", targetPath, time.Now().UTC().Format("20060102T150405.000000000Z"))
	file, err := os.OpenFile(backupPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, mode)
	if err != nil {
		return "", err
	}
	// abandon closes and deletes the incomplete backup; cleanup errors are
	// ignored because the original failure is what the caller needs.
	abandon := func(cause error) (string, error) {
		_ = file.Close()
		_ = os.Remove(backupPath)
		return "", cause
	}
	if _, err := file.Write(content); err != nil {
		return abandon(err)
	}
	if err := file.Sync(); err != nil {
		return abandon(err)
	}
	if err := file.Close(); err != nil {
		_ = os.Remove(backupPath)
		return "", err
	}
	return backupPath, nil
}
// atomicWrite replaces targetPath with content. It writes a hidden temporary
// file in the same directory (so the rename stays on one filesystem),
// applies mode to it, flushes it to disk and renames it over the target.
//
// Readers therefore see either the old file or the complete new one. On
// failure the temporary file is removed and the target is left untouched.
func atomicWrite(targetPath string, content []byte, mode os.FileMode) error {
	dir := filepath.Dir(targetPath)
	tmp, err := os.CreateTemp(dir, "."+filepath.Base(targetPath)+".tmp-*")
	if err != nil {
		return err
	}
	tmpPath := tmp.Name()
	renamed := false
	defer func() {
		if !renamed {
			// Best-effort cleanup; the error being returned matters more.
			_ = os.Remove(tmpPath)
		}
	}()

	if err := tmp.Chmod(mode); err != nil {
		_ = tmp.Close()
		return err
	}
	if _, err := tmp.Write(content); err != nil {
		_ = tmp.Close()
		return err
	}
	// Flush the data before the rename publishes it, so a crash does not
	// leave the target name pointing at an empty or partial file.
	if err := tmp.Sync(); err != nil {
		_ = tmp.Close()
		return err
	}
	if err := tmp.Close(); err != nil {
		return err
	}
	if err := os.Rename(tmpPath, targetPath); err != nil {
		return err
	}
	renamed = true

	// Persist the directory entry as well. Some platforms (e.g. Windows) do
	// not support syncing a directory, so this step is best-effort.
	if d, err := os.Open(dir); err == nil {
		_ = d.Sync()
		_ = d.Close()
	}
	return nil
}
// hashBytes returns the SHA-256 digest of data as a lowercase hex string
// (64 characters).
func hashBytes(data []byte) string {
	digest := sha256.Sum256(data)
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return string(encoded)
}
// changedFields compares two field maps and returns one FieldChange per key
// whose value differs, ordered by key. A key missing from one map compares
// as the empty string. The result is never nil (empty when nothing changed).
func changedFields(before map[string]string, after map[string]string) []FieldChange {
	union := make(map[string]struct{}, len(before)+len(after))
	for _, side := range []map[string]string{before, after} {
		for key := range side {
			union[key] = struct{}{}
		}
	}

	names := make([]string, 0, len(union))
	for name := range union {
		names = append(names, name)
	}
	sort.Strings(names)

	changes := []FieldChange{}
	for _, name := range names {
		oldValue, newValue := before[name], after[name]
		if oldValue != newValue {
			changes = append(changes, FieldChange{Field: name, Before: oldValue, After: newValue})
		}
	}
	return changes
}
// simpleDiff renders a minimal single-hunk diff from before to after.
//
// Identical inputs produce "无差异\n". Otherwise the common leading and
// trailing lines are dropped and the remaining middle is printed after a
// "--- current / +++ draft / @@" header: removed lines prefixed with "-",
// then added lines prefixed with "+".
func simpleDiff(before string, after string) string {
	if before == after {
		return "无差异\n"
	}
	oldLines, newLines := splitLines(before), splitLines(after)

	// Skip the shared leading lines.
	head := 0
	for head < len(oldLines) && head < len(newLines) && oldLines[head] == newLines[head] {
		head++
	}
	// Skip the shared trailing lines, never re-consuming the shared head.
	oldEnd, newEnd := len(oldLines), len(newLines)
	for oldEnd > head && newEnd > head && oldLines[oldEnd-1] == newLines[newEnd-1] {
		oldEnd--
		newEnd--
	}

	var out strings.Builder
	out.WriteString("--- current\n+++ draft\n@@\n")
	emit := func(marker string, lines []string) {
		for _, line := range lines {
			out.WriteString(marker)
			out.WriteString(line)
			out.WriteString("\n")
		}
	}
	emit("-", oldLines[head:oldEnd])
	emit("+", newLines[head:newEnd])
	return out.String()
}
// splitLines splits input on "\n", ignoring a single trailing newline.
// Inputs that are empty or just "\n" yield an empty, non-nil slice.
func splitLines(input string) []string {
	if input == "" || input == "\n" {
		return []string{}
	}
	return strings.Split(strings.TrimSuffix(input, "\n"), "\n")
}