Go 関数を完成させる

Go を使用して CSV ファイルに基づいて SQL スクリプトを生成する

モデルにクエリを直接実行して、返された結果をテストできます。その場合、Cloud コンソールでさまざまなパラメータ値を使用することも、Vertex AI API を直接呼び出すこともできます。

システム指示

あなたは SQL データベースを専門とする Go デベロッパーです。

自由形式のプロンプト

CSV を SQL に変換するコードがあります。テーブルを作成する関数とスクリプトを挿入する関数が完成していません。この 2 つの関数を完成させてもらえますか?

convert.go

package main

import (
  "encoding/csv"
  "fmt"
  "os"
  "strconv"
  "strings"
  "time"
)

func inferDataType(value string, dialect string) string {
  value = strings.TrimSpace(value) // Trim leading and trailing spaces
  if _, err := strconv.Atoi(value); err == nil {
    return getIntegerType(dialect)
  }
  if _, err := strconv.ParseFloat(value, 64); err == nil {
    return getFloatType(dialect)
  }
  if _, err := time.Parse(time.RFC3339, value); err == nil {
    return getTimestampType(dialect)
  }
  if _, err := time.Parse("2006-01-02 15:04:05", value); err == nil {
    return getDateTimeType(dialect)
  }
  if _, err := time.Parse("2006-01-02", value); err == nil {
    return getDateType(dialect)
  }
  if len(value) > 255 {
    return "TEXT"
  }
  return "VARCHAR(255)"
}

func getIntegerType(dialect string) string {
  switch dialect {
  case "mysql":
    return "INT"
  default:
    return "INTEGER"
  }
}

func getFloatType(dialect string) string {
  switch dialect {
  case "mysql":
    return "FLOAT"
  default:
    return "REAL"
  }
}

func getTimestampType(dialect string) string {
  switch dialect {
  case "postgres":
    return "TIMESTAMP"
  case "mysql":
    return "TIMESTAMP"
  case "sqlite":
    return "DATETIME"
  default:
    return "TEXT"
  }
}

func getDateTimeType(dialect string) string {
  switch dialect {
  case "postgres":
    return "TIMESTAMP"
  case "mysql":
    return "DATETIME"
  case "sqlite":
    return "DATETIME"
  default:
    return "TEXT"
  }
}

func getDateType(dialect string) string {
  switch dialect {
  case "postgres", "mysql", "sqlite":
    return "DATE"
  default:
    return "TEXT"
  }
}

func validateDataType(value string, dataType string) error {
  value = strings.TrimSpace(value)
  switch dataType {
  case "INTEGER":
    if _, err := strconv.Atoi(value); err != nil {
      return err
    }
  case "REAL":
    if _, err := strconv.ParseFloat(value, 64); err != nil {
      return err
    }
  case "TIMESTAMP":
    if _, err := time.Parse(time.RFC3339, value); err != nil {
      return err
    }
  case "DATETIME":
    if _, err := time.Parse("2006-01-02 15:04:05", value); err != nil {
      return err
    }
  case "DATE":
    if _, err := time.Parse("2006-01-02", value); err != nil {
      return err
    }
  }
  return nil
}

func generateCreateTableScript(tableName string, headers []string, sampleData []string, dialect string) string {

}

func generateInsertScript(tableName string, headers []string, rows [][]string, dataTypes []string) (string, error) {

}

func csvToSQL(csvFilePath, tableName, dialect string) (string, string, error) {
  file, err := os.Open(csvFilePath)
  if err != nil {
    return "", "", err
  }
  defer file.Close()

  reader := csv.NewReader(file)
  headers, err := reader.Read()
  if err != nil {
    return "", "", err
  }

  // Convert headers to snake_case
  fixedHeaders := make([]string, len(headers))
  for i, header := range headers {

    fixedHeader := strings.ReplaceAll(strings.TrimSpace(header), " ", "_")
    fixedHeaders[i] = fixedHeader
  }

  sampleData, err := reader.Read()
  if err != nil {
    return "", "", err
  }

  rows, err := reader.ReadAll()
  if err != nil {
    return "", "", err
  }

  dataTypes := make([]string, len(fixedHeaders))
  for i, sample := range sampleData {
    dataTypes[i] = inferDataType(sample, dialect)
  }

  createTableScript := generateCreateTableScript(tableName, fixedHeaders, sampleData, dialect)
  insertScript, err := generateInsertScript(tableName, fixedHeaders, rows, dataTypes)
  if err != nil {
    return "", "", err
  }

  return createTableScript, insertScript, nil
}

main.go

package main

import (
  "flag"
  "fmt"
  "os"
  "path/filepath"
)

func main() {
  csvFilePath := flag.String("f", "", "Path to the CSV file")
  tableName := flag.String("t", "default_table", "Name of the table")
  dialect := flag.String("d", "postgres", "SQL dialect (postgres, mysql, sqlite)")

  flag.Parse()

  // Check if file path is provided
  if *csvFilePath == "" {
    fmt.Println("Error: CSV file path is required")
    flag.Usage()
    os.Exit(1)
  }

  createScript, insertScript, err := csvToSQL(*csvFilePath, *tableName, *dialect)
  if err != nil {
    fmt.Println("Error:", err)
    return
  }

  createFilePath := filepath.Join(".", fmt.Sprintf("%s_create.sql", *tableName))
  err = os.WriteFile(createFilePath, []byte(createScript), 0644)
  if err != nil {
    fmt.Println("Error writing CREATE TABLE script:", err)
    return
  }

  insertFilePath := filepath.Join(".", fmt.Sprintf("%s_insert.sql", *tableName))
  err = os.WriteFile(insertFilePath, []byte(insertScript), 0644)
  if err != nil {
    fmt.Println("Error writing INSERT scripts:", err)
    return
  }

  fmt.Printf("CREATE TABLE script written to: %s\n", createFilePath)
  fmt.Printf("INSERT scripts written to: %s\n", insertFilePath)
}

レスポンス

package main

import (
    "encoding/csv"
    "fmt"
    "os"
    "strconv"
    "strings"
    "time"
)

func inferDataType(value string, dialect string) string {
    value = strings.TrimSpace(value) // Trim leading and trailing spaces
    if _, err := strconv.Atoi(value); err == nil {
        return getIntegerType(dialect)
    }
    if _, err := strconv.ParseFloat(value, 64); err == nil {
        return getFloatType(dialect)
    }
    if _, err := time.Parse(time.RFC3339, value); err == nil {
        return getTimestampType(dialect)
    }
    if _, err := time.Parse("2006-01-02 15:04:05", value); err == nil {
        return getDateTimeType(dialect)
    }
    if _, err := time.Parse("2006-01-02", value); err == nil {
        return getDateType(dialect)
    }
    if len(value) > 255 {
        return "TEXT"
    }
    return "VARCHAR(255)"
}

func getIntegerType(dialect string) string {
    switch dialect {
    case "mysql":
        return "INT"
    default:
        return "INTEGER"
    }
}

func getFloatType(dialect string) string {
    switch dialect {
    case "mysql":
        return "FLOAT"
    default:
        return "REAL"
    }
}

func getTimestampType(dialect string) string {
    switch dialect {
    case "postgres":
        return "TIMESTAMP"
    case "mysql":
        return "TIMESTAMP"
    case "sqlite":
        return "DATETIME"
    default:
        return "TEXT"
    }
}

func getDateTimeType(dialect string) string {
    switch dialect {
    case "postgres":
        return "TIMESTAMP"
    case "mysql":
        return "DATETIME"
    case "sqlite":
        return "DATETIME"
    default:
        return "TEXT"
    }
}

func getDateType(dialect string) string {
    switch dialect {
    case "postgres", "mysql", "sqlite":
        return "DATE"
    default:
        return "TEXT"
    }
}

func validateDataType(value string, dataType string) error {
    value = strings.TrimSpace(value)
    switch dataType {
    case "INTEGER":
        if _, err := strconv.Atoi(value); err != nil {
            return err
        }
    case "REAL":
        if _, err := strconv.ParseFloat(value, 64); err != nil {
            return err
        }
    case "TIMESTAMP":
        if _, err := time.Parse(time.RFC3339, value); err != nil {
            return err
        }
    case "DATETIME":
        if _, err := time.Parse("2006-01-02 15:04:05", value); err != nil {
            return err
        }
    case "DATE":
        if _, err := time.Parse("2006-01-02", value); err != nil {
            return err
        }
    }
    return nil
}

func generateCreateTableScript(tableName string, headers []string, sampleData []string, dialect string) string {
    script := fmt.Sprintf("CREATE TABLE %s (\n", tableName)
    for i, header := range headers {
        dataType := inferDataType(sampleData[i], dialect)
        script += fmt.Sprintf("  %s %s", header, dataType)
        if i < len(headers)-1 {
            script += ",\n"
        } else {
            script += "\n"
        }
    }
    script += ");"
    return script
}

func generateInsertScript(tableName string, headers []string, rows [][]string, dataTypes []string) (string, error) {
    script := ""
    for _, row := range rows {
        if len(row) != len(headers) {
            return "", fmt.Errorf("row length mismatch")
        }
        values := make([]string, len(row))
        for i, value := range row {
            err := validateDataType(value, dataTypes[i])
            if err != nil {
                return "", fmt.Errorf("data type validation failed for value '%s': %w", value, err)
            }
            //Escape strings for SQL injection prevention.  This is a simplification and should be improved for production use.
            if dataTypes[i] == "VARCHAR(255)" || dataTypes[i] == "TEXT" {
                values[i] = "'" + strings.ReplaceAll(value, "'", "''") + "'"
            } else {
                values[i] = value
            }

        }
        script += fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);\n", tableName, strings.Join(headers, ","), strings.Join(values, ","))
    }
    return script, nil
}

func csvToSQL(csvFilePath, tableName, dialect string) (string, string, error) {
    file, err := os.Open(csvFilePath)
    if err != nil {
        return "", "", err
    }
    defer file.Close()

    reader := csv.NewReader(file)
    headers, err := reader.Read()
    if err != nil {
        return "", "", err
    }

    // Convert headers to snake_case
    fixedHeaders := make([]string, len(headers))
    for i, header := range headers {

        fixedHeader := strings.ReplaceAll(strings.TrimSpace(header), " ", "_")
        fixedHeaders[i] = fixedHeader
    }

    sampleData, err := reader.Read()
    if err != nil {
        return "", "", err
    }

    rows, err := reader.ReadAll()
    if err != nil {
        return "", "", err
    }

    dataTypes := make([]string, len(fixedHeaders))
    for i, sample := range sampleData {
        dataTypes[i] = inferDataType(sample, dialect)
    }

    createTableScript := generateCreateTableScript(tableName, fixedHeaders, sampleData, dialect)
    insertScript, err := generateInsertScript(tableName, fixedHeaders, rows, dataTypes)
    if err != nil {
        return "", "", err
    }

    return createTableScript, insertScript, nil
}

更新されたコードには、完成した generateCreateTableScript 関数と generateInsertScript 関数が含まれています。generateCreateTableScript は、推定されたデータ型に基づいて CREATE TABLE ステートメントを構築します。generateInsertScript は、CSV の行ごとに INSERT ステートメントを生成し、単純な SQL インジェクションを防ぐための基本的な SQL エスケープも含めます。注: generateInsertScript のエスケープは基本的なものです。より高度な SQL インジェクション攻撃を防ぐため、本番環境で使用する場合は、パラメータ化された適切なクエリ メカニズムに置き換える必要があります。また、エラー処理も改善して、より有用なエラー メッセージが表示されるようにしました。メイン関数で潜在的なエラーを適切に処理してください。

モデル: gemini-1.5-flash-002
温度: 0.2
最大出力トークン: 8192
TopK: 40.0
TopP: 0.95