package main

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strings"

	"git.sr.ht/~charles/rq/builtins"
	rqio "git.sr.ht/~charles/rq/io"
	"git.sr.ht/~charles/rq/version"

	"github.com/open-policy-agent/opa/v1/ast"
	opabundle "github.com/open-policy-agent/opa/v1/bundle"
	"github.com/open-policy-agent/opa/v1/loader"

	"github.com/mattn/go-isatty"
)

// Process-wide state shared between the subcommands and main().
var (
	// forceNonzeroExit is flipped to true by the -T/--test check so that
	// main() terminates with a nonzero exit code.
	forceNonzeroExit bool

	// stdin and stdout are accessed through these variables instead of
	// os.Stdin/os.Stdout directly, so tests can substitute their own
	// reader and writer.
	stdin  io.Reader = os.Stdin
	stdout io.Writer = os.Stdout
)

// CapabilitiesCmd implements the "rq capabilities" subcommand.
type CapabilitiesCmd struct {
}

// Run implements the main method for the "rq capabilities" subcommand. It
// writes the value returned by ast.CapabilitiesForThisVersion to stdout as
// JSON indented with two spaces. No trailing newline is written after the
// JSON document.
//
// Errors from marshaling or from writing to stdout are returned unchanged.
func (c *CapabilitiesCmd) Run(globals *Globals) error {
	// This code is largely based on the implementation of the `opa
	// capabilities` command located here:
	// https://github.com/open-policy-agent/opa/blob/main/cmd/capabilities.go
	caps := ast.CapabilitiesForThisVersion()

	bs, err := json.MarshalIndent(caps, "", "  ")
	if err != nil {
		return err
	}

	_, err = stdout.Write(bs)
	return err
}

// VersionCmd implements the "rq version" subcommand.
type VersionCmd struct {
}

// Run implements the main method for the "rq version" subcommand.
//
// It JSON-encodes version.Version() and prints it by delegating to the
// "rq query" implementation with the query "input". The version info is
// always rendered as uncolored YAML. Any error from encoding or from the
// delegated query run is returned.
func (v *VersionCmd) Run(globals *Globals) error {
	vd, err := json.Marshal(version.Version())
	if err != nil {
		return err
	}

	// We pass in the version info as the input, using the base64 handler.
	// This is a bit clunky, but performance isn't critical here, and I
	// want to avoid mutating the stdin/stdout globals.
	inputMap := map[string]interface{}{
		"format":    "base64",
		"file_path": "",
		"options": map[string]interface{}{
			"base64.data": base64.StdEncoding.EncodeToString(vd),
		},
	}

	inputData, err := json.Marshal(inputMap)
	if err != nil {
		return err
	}

	// VersionCmd has no flags of its own, so this Common is built from
	// scratch. Every option not set here keeps its zero value, including
	// InputFormat. YAML is easy to read and easy to parse, and color is
	// disabled so the output is stable regardless of the terminal.
	common := Common{}
	common.Input = string(inputData)
	common.OutputFormat = "yaml"
	common.NoColor = true

	qc := QueryCmd{
		Common: &common,
		Query:  "input",
	}

	return qc.Run(globals)
}

// ListFormatsCmd implements the "rq list-formats" subcommand.
type ListFormatsCmd struct {
}

// Run implements the main method for the "rq list-formats" subcommand.
//
// It writes every name from rqio.ListInputHandlers and
// rqio.ListOutputHandlers to stdout, one per tab-indented line under an
// "Input Formats:" / "Output Formats:" heading. It then writes a note
// explaining the 'raw' output format. The first error encountered while
// writing to stdout is returned.
func (l *ListFormatsCmd) Run(globals *Globals) error {
	if _, err := fmt.Fprint(stdout, "Input Formats:\n"); err != nil {
		return err
	}
	for _, h := range rqio.ListInputHandlers() {
		if _, err := fmt.Fprintf(stdout, "\t%s\n", h); err != nil {
			return err
		}
	}
	if _, err := fmt.Fprint(stdout, "\nOutput Formats:\n"); err != nil {
		return err
	}
	for _, h := range rqio.ListOutputHandlers() {
		if _, err := fmt.Fprintf(stdout, "\t%s\n", h); err != nil {
			return err
		}
	}

	// Propagate a failure writing the note, consistent with the writes
	// above (previously this error was silently dropped).
	_, err := fmt.Fprint(stdout, `
NOTE: the 'raw' output format is the same as 'json', except syntax highlighting
and pretty printing are disabled, and if the result is a string, it will be
printed without quotes. If the output is a list of primitive types, they will
be printed one per line without quotes. This can be useful when using rq as a
filter for other shell tools.
`)
	return err
}

// ScriptCmd implements the "rq script" subcommand.
//
// File is the path of the Rego script to execute. Args collects any remaining
// positional arguments. Run publishes them to the script's rq.args() builtin
// by assigning them to builtins.Args. Declaring Args here also makes the fact
// that extra arguments are accepted show up in the help text.
type ScriptCmd struct {
	*Common
	File string   `help:"Rego Script" required:"" arg:"" default:""`
	Args []string `help:"Arguments to expose to the script via rq.args()." arg:"" optional:""`
}

// scriptPackageFromSource scans src line by line. If it finds a line whose
// first whitespace-separated field is the keyword "package" and which has at
// least one more field, it returns that next field and true. Otherwise it
// returns "" and false.
//
// Splitting on arbitrary whitespace (rather than a single space) means that
// declarations such as "package  foo" or "package\tfoo" are recognized and
// yield "foo", instead of being skipped or yielding an empty package name.
func scriptPackageFromSource(src string) (string, bool) {
	for _, line := range strings.Split(src, "\n") {
		fields := strings.Fields(line)
		if len(fields) >= 2 && fields[0] == "package" {
			return fields[1], true
		}
	}
	return "", false
}

// Run implements the main method for the "rq script" subcommand.
//
// It performs these steps:
//   - Reads the script at s.File, prepending "package script" if the script
//     declares no package.
//   - Applies the directives returned by parseScriptConfig on top of the
//     command line options.
//   - Loads the input and data.
//   - Merges the script with any bundles from BundlePaths.
//   - Evaluates each silent-query, discarding its result.
//   - Evaluates the main query (by default data.<package>) and writes the
//     result to the configured output.
//
// If save-bundle is set, the merged bundle is saved instead of evaluating
// the main query. The main query is evaluated before the output file is
// opened (and thereby truncated), so a failing query never clobbers an
// in-place target.
func (s *ScriptCmd) Run(globals *Globals) error {
	if s.File == "" {
		return errors.New("No input file supplied")
	}

	builtins.Args = s.Args

	sp, err := filepath.Abs(s.File)
	if err != nil {
		return err
	}
	builtins.ScriptPath = sp

	scriptBytes, err := os.ReadFile(s.File)
	if err != nil {
		return err
	}

	// If the script does not declare a package, inject `package script`
	// so that its rules can be addressed as data.script.
	scriptPackage, foundPackage := scriptPackageFromSource(string(scriptBytes))
	if !foundPackage {
		scriptPackage = "script"
		scriptBytes = append([]byte("package "+scriptPackage+"\n"), scriptBytes...)
	}

	defaultQuery := "data." + scriptPackage

	directives, err := parseScriptConfig(string(scriptBytes))
	if err != nil {
		return err
	}

	// Each getter receives the current option value (or the default
	// query) as its fallback argument. The getters' second return values
	// are ignored, as before.
	query, _ := getScalar(directives, "query", defaultQuery)
	s.Common.Input, _ = getScalar(directives, "input", s.Common.Input)
	s.Common.InPlace, _ = getScalar(directives, "in-place", s.Common.InPlace)
	s.Common.InputFormat, _ = getScalar(directives, "input-format", s.Common.InputFormat)
	s.Common.Output, _ = getScalar(directives, "output", s.Common.Output)
	s.Common.OutputFormat, _ = getScalar(directives, "output-format", s.Common.OutputFormat)
	s.Common.NoColor, _ = getBool(directives, "no-color", s.Common.NoColor)
	s.Common.ForceColor, _ = getBool(directives, "force-color", s.Common.ForceColor)
	s.Common.Ugly, _ = getBool(directives, "ugly", s.Common.Ugly)
	s.Common.Indent, _ = getScalar(directives, "indent", s.Common.Indent)
	s.Common.Style, _ = getScalar(directives, "style", s.Common.Style)
	s.Common.CSVComma, _ = getScalar(directives, "csv-comma", s.Common.CSVComma)
	s.Common.CSVComment, _ = getScalar(directives, "csv-comment", s.Common.CSVComment)
	s.Common.CSVSkipLines, _ = getInt(directives, "csv-skip-lines", s.Common.CSVSkipLines)
	s.Common.CSVHeaders, _ = getBool(directives, "csv-headers", s.Common.CSVHeaders)
	s.Common.CSVNoInfer, _ = getBool(directives, "csv-no-infer", s.Common.CSVNoInfer)
	s.Common.Profile, _ = getBool(directives, "profile", s.Common.Profile)
	s.Common.DataPaths, _ = getVector(directives, "data-paths", s.Common.DataPaths)
	s.Common.BundlePaths, _ = getVector(directives, "bundle-paths", s.Common.BundlePaths)
	s.Common.Raw, _ = getBool(directives, "raw", s.Common.Raw)
	s.Common.Save, _ = getScalar(directives, "save-bundle", s.Common.Save)
	s.Common.Template, _ = getScalar(directives, "template", s.Common.Template)
	s.Common.Check, _ = getBool(directives, "check", s.Common.Check)
	s.Common.V0Compatible, _ = getBool(directives, "v0-compatible", s.Common.V0Compatible)
	silentQueries, _ := getVector(directives, "silent-query", []string{})

	// This would otherwise be silently rewritten to json later on.
	if s.Common.InputFormat != "" {
		if _, err := rqio.SelectInputHandler(s.Common.InputFormat); err != nil {
			return err
		}
	}
	if s.Common.OutputFormat != "" {
		if _, err := rqio.SelectOutputHandler(s.Common.OutputFormat); err != nil {
			return err
		}
	}

	var reader io.Reader = stdin

	// If reading from a terminal, silently assume {} as an empty input.
	if f, ok := stdin.(*os.File); ok && f == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) && !s.Common.NoTTYCheck {
		reader = bytes.NewBufferString("{}")
	}

	inputDefaults, err := defaultInputOptions(s.Common)
	if err != nil {
		return err
	}

	inspec := s.Common.Input
	if s.Common.InPlace != "" {
		if inspec != "" {
			return errors.New("rq: input cannot be used with in-place")
		}
		inspec = s.Common.InPlace
	}

	inputDs, err := rqio.ParseDataSpec(inspec)
	if err != nil {
		return err
	}

	inputDs.ResolveDefaults(inputDefaults)

	// If the dataspec has a FilePath, then it takes precedence over
	// whatever stdin is.
	if inputDs.FilePath != "" {
		f, err := os.Open(inputDs.FilePath)
		if err != nil {
			return err
		}
		defer func() { _ = f.Close() }()
		reader = f
	}

	// The reader is now either standard in or the given file path.
	parsed, err := rqio.LoadInputFromReader(inputDs, reader)
	if err != nil {
		return err
	}

	data, err := loadData(s.Common.DataPaths, s.Common)
	if err != nil {
		return err
	}

	opt := ast.ParserOptions{}
	if s.Common.V0Compatible {
		opt.RegoVersion = ast.RegoV0
	}

	module, err := ast.ParseModuleWithOpts(s.File, string(scriptBytes), opt)
	if err != nil {
		return err
	}

	regoVersion := 1
	if s.Common.V0Compatible {
		regoVersion = 0
	}

	scriptBundle := &opabundle.Bundle{
		Manifest: opabundle.Manifest{
			Roots:       &[]string{},
			RegoVersion: &regoVersion,
		},
		Modules: []opabundle.ModuleFile{
			{
				Path:   s.File,
				Raw:    scriptBytes,
				Parsed: module,
			},
		},
	}

	bundles := make([]*opabundle.Bundle, len(s.Common.BundlePaths))
	for i, p := range s.Common.BundlePaths {
		l := loader.NewFileLoader()
		if s.Common.V0Compatible {
			l.WithRegoVersion(ast.RegoV0)
		}
		b, err := l.AsBundle(p)
		if err != nil {
			return err
		}
		if b == nil {
			return fmt.Errorf("loading bundle from path '%s' returned nil", p)
		}
		bundles[i] = b
	}
	bundles = append(bundles, scriptBundle)

	bundle := mergeBundles(bundles, nil, &regoVersion)

	// Results of silent queries are discarded; only their errors matter.
	for _, sq := range silentQueries {
		if _, err := runRegoQuery(parsed, sq, data, bundle, regoVersion); err != nil {
			return err
		}
	}

	if s.Common.Save != "" {
		return saveBundle(data, bundle, s.Common.Save, &regoVersion)
	}

	if s.Common.Raw {
		s.Common.OutputFormat = "raw"
	}

	outputIsTTY := s.Common.Output == "" && isatty.IsTerminal(os.Stdout.Fd()) && !s.Common.NoTTYCheck && s.Common.InPlace == ""
	outputDefaults, err := defaultOutputOptions(s.Common, outputIsTTY)
	if err != nil {
		return err
	}

	outspec := s.Common.Output
	if s.Common.InPlace != "" {
		if outspec != "" {
			return errors.New("rq: output cannot be used with in-place")
		}
		outspec = s.Common.InPlace
	}

	outputDs, err := rqio.ParseDataSpec(outspec)
	if err != nil {
		return err
	}

	outputDs.ResolveDefaults(outputDefaults)

	// Evaluate the main query BEFORE opening the output file. Opening the
	// file truncates it, so doing that first would erase an in-place
	// target (or leave an empty output file behind) whenever the query
	// fails.
	var outputData interface{}
	outputData, err = runRegoQuery(parsed, query, data, bundle, regoVersion)
	if err != nil {
		return err
	}

	if s.Common.Check && testCheck(outputData) {
		forceNonzeroExit = true
	}

	writer := stdout
	var outFile *os.File
	if outputDs.FilePath != "" {
		f, err := os.OpenFile(outputDs.FilePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
		if err != nil {
			return err
		}
		// Safety net for early returns. On success the file is closed
		// explicitly below so that a failed close (lost write) is
		// reported; the second Close here is then a harmless no-op.
		defer func() { _ = f.Close() }()
		outFile = f
		writer = f
	}

	if err := rqio.WriteOutputToWriter(outputData, outputDs, writer); err != nil {
		return err
	}

	if !s.Common.Ugly && outputDs.Format != "null" {
		// Trailing newline
		if _, err := writer.Write([]byte("\n")); err != nil {
			return err
		}
	}

	if outFile != nil {
		return outFile.Close()
	}
	return nil
}

// QueryCmd implements the "rq query" subcommand.
type QueryCmd struct {
	*Common
	Query string `help:"Rego Query" arg:"" default:"input"`
}

// Run implements the main method for the "rq query" subcommand.
//
// It loads the input document from stdin, from -I/--input, or from the
// -p/--in-place file. It evaluates q.Query against that input (together with
// any data from DataPaths and bundles from BundlePaths) and writes the result
// to stdout, -O/--output, or the -p/--in-place file.
//
// Special cases:
//   - An empty query is treated as "input". When the query is "input", the
//     parsed input is written back without loading data or evaluating Rego.
//   - If Common.Save is set, the data and merged bundle are passed to
//     saveBundle instead, and no query is evaluated.
//   - When Common.Check is set and testCheck returns true for the result,
//     forceNonzeroExit is set.
//
// The result is computed before the output file is opened (and thereby
// truncated), so a failing query never clobbers an in-place target.
func (q *QueryCmd) Run(globals *Globals) error {
	// Advice is only a warning: it is printed to stderr and execution
	// continues regardless. A failure to print it is deliberately ignored.
	if advice := helpfulAdvice(q.Query, q.Common); advice != nil {
		_, _ = fmt.Fprintln(os.Stderr, advice.Error())
	}

	// This would otherwise be silently rewritten to json later on.
	if q.Common.InputFormat != "" {
		if _, err := rqio.SelectInputHandler(q.Common.InputFormat); err != nil {
			return err
		}
	}
	if q.Common.OutputFormat != "" {
		if _, err := rqio.SelectOutputHandler(q.Common.OutputFormat); err != nil {
			return err
		}
	}

	var reader io.Reader = stdin

	// If reading from a terminal, silently assume {} as an empty input.
	if f, ok := stdin.(*os.File); ok && f == os.Stdin && isatty.IsTerminal(os.Stdin.Fd()) && !q.Common.NoTTYCheck {
		reader = bytes.NewBufferString("{}")
	}

	inputDefaults, err := defaultInputOptions(q.Common)
	if err != nil {
		return err
	}

	inspec := q.Common.Input
	if q.Common.InPlace != "" {
		if inspec != "" {
			return errors.New("rq: -I/--input cannot be used with -p/--in-place")
		}
		inspec = q.Common.InPlace
	}

	inputDs, err := rqio.ParseDataSpec(inspec)
	if err != nil {
		return err
	}

	inputDs.ResolveDefaults(inputDefaults)

	// If the dataspec has a FilePath, then it takes precedence over
	// whatever stdin is.
	if inputDs.FilePath != "" {
		f, err := os.Open(inputDs.FilePath)
		if err != nil {
			return err
		}
		defer func() { _ = f.Close() }()
		reader = f
	}

	// The reader is now either standard in or the given file path.
	parsed, err := rqio.LoadInputFromReader(inputDs, reader)
	if err != nil {
		return err
	}

	if strings.TrimSpace(q.Query) == "" {
		q.Query = "input"
	}

	regoVersion := 1
	if q.Common.V0Compatible {
		regoVersion = 0
	}

	bundles := make([]*opabundle.Bundle, len(q.Common.BundlePaths))
	for i, p := range q.Common.BundlePaths {
		l := loader.NewFileLoader()
		if q.Common.V0Compatible {
			l.WithRegoVersion(ast.RegoV0)
		}
		b, err := l.AsBundle(p)
		if err != nil {
			return err
		}
		if b == nil {
			return fmt.Errorf("loading bundle from path '%s' returned nil", p)
		}
		bundles[i] = b
	}
	bundle := mergeBundles(bundles, nil, &regoVersion)

	if q.Common.Save != "" {
		data, err := loadData(q.Common.DataPaths, q.Common)
		if err != nil {
			return err
		}
		return saveBundle(data, bundle, q.Common.Save, &regoVersion)
	}

	if q.Common.Raw {
		q.Common.OutputFormat = "raw"
	}

	outputIsTTY := q.Common.Output == "" && isatty.IsTerminal(os.Stdout.Fd()) && !q.Common.NoTTYCheck && q.Common.InPlace == ""
	outputDefaults, err := defaultOutputOptions(q.Common, outputIsTTY)
	if err != nil {
		return err
	}

	outspec := q.Common.Output
	if q.Common.InPlace != "" {
		if outspec != "" {
			return errors.New("rq: -O/--output cannot be used with -p/--in-place")
		}
		outspec = q.Common.InPlace
	}

	outputDs, err := rqio.ParseDataSpec(outspec)
	if err != nil {
		return err
	}

	outputDs.ResolveDefaults(outputDefaults)

	// Compute the result BEFORE opening the output file. Opening the file
	// truncates it, so doing that first would erase an in-place target
	// (or leave an empty output file behind) whenever loading data or
	// evaluating the query fails.
	var outputData interface{}
	if strings.TrimSpace(q.Query) == "input" {
		outputData = parsed
	} else {
		data, err := loadData(q.Common.DataPaths, q.Common)
		if err != nil {
			return err
		}

		outputData, err = runRegoQuery(parsed, q.Query, data, bundle, regoVersion)
		if err != nil {
			return err
		}
	}

	if q.Common.Check && testCheck(outputData) {
		forceNonzeroExit = true
	}

	writer := stdout
	var outFile *os.File
	if outputDs.FilePath != "" {
		f, err := os.OpenFile(outputDs.FilePath, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0644)
		if err != nil {
			return err
		}
		// Safety net for early returns. On success the file is closed
		// explicitly below so that a failed close (lost write) is
		// reported; the second Close here is then a harmless no-op.
		defer func() { _ = f.Close() }()
		outFile = f
		writer = f
	}

	if err := rqio.WriteOutputToWriter(outputData, outputDs, writer); err != nil {
		return err
	}

	if !q.Common.Ugly && outputDs.Format != "null" {
		// Trailing newline
		if _, err := writer.Write([]byte("\n")); err != nil {
			return err
		}
	}

	if outFile != nil {
		return outFile.Close()
	}
	return nil
}
