diff --git a/spec/pg/decoder_spec.cr b/spec/pg/decoder_spec.cr index 1e75b014..4a0c87d0 100644 --- a/spec/pg/decoder_spec.cr +++ b/spec/pg/decoder_spec.cr @@ -1,4 +1,6 @@ require "uuid" +# Extensions must be loaded before spec_helper connects to the DB +require "../../src/pg/extensions/citext" require "../spec_helper" describe PG::Decoders do @@ -92,4 +94,15 @@ describe PG::Decoders do test_decode "path ", "'(1,2,3,4)'::path ", PG::Geo::Path.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0)], closed: true) test_decode "path ", "'[1,2,3,4,5,6]'::path", PG::Geo::Path.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0), PG::Geo::Point.new(5.0, 6.0)], closed: false) test_decode "polygon", "'1,2,3,4,5,6'::polygon", PG::Geo::Polygon.new([PG::Geo::Point.new(1.0, 2.0), PG::Geo::Point.new(3.0, 4.0), PG::Geo::Point.new(5.0, 6.0)]) + + it "decodes extension types on a per-connection basis" do + PG_DB.exec "CREATE EXTENSION IF NOT EXISTS citext" + PG_DB.query "select 'OMG lol'::citext" do |rs| + rs.each do + text = rs.read + text.should be_a PG::CIText + text.should eq "omg LOL" + end + end + end end diff --git a/src/pg/connection.cr b/src/pg/connection.cr index c84cad92..2f637e5e 100644 --- a/src/pg/connection.cr +++ b/src/pg/connection.cr @@ -1,12 +1,22 @@ +require "./decoder" require "../pq/*" module PG class Connection < ::DB::Connection protected getter connection + private EXTENSIONS = [] of Extension + + def self.register_extension(extension : Extension) + EXTENSIONS << extension + end + def initialize(context) super @connection = uninitialized PQ::Connection + @decoders = Hash(Int32, Decoders::Decoder).new do |_, oid| + Decoders.from_oid(oid) + end begin conn_info = PQ::ConnInfo.new(context.uri) @@ -15,6 +25,25 @@ module PG rescue ex raise DB::ConnectionRefused.new(cause: ex) end + + auto_release = @auto_release + @auto_release = false + begin + EXTENSIONS.each(&.load(self)) + ensure + @auto_release = auto_release + end + end + + def register_decoder(decoder : Decoders::Decoder) + decoder.oids.each do |oid| + @decoders[oid] = decoder + end + end + + @checked_oids = Set(Int32).new + def decoder(oid) + @decoders[oid] end def build_prepared_statement(query) : Statement diff --git a/src/pg/extensions.cr b/src/pg/extensions.cr new file mode 100644 index 00000000..be47ad1b --- /dev/null +++ b/src/pg/extensions.cr @@ -0,0 +1,5 @@ +module PG + module Extension + abstract def load(connection : Connection) + end +end diff --git a/src/pg/extensions/citext.cr b/src/pg/extensions/citext.cr new file mode 100644 index 00000000..79e5e63b --- /dev/null +++ b/src/pg/extensions/citext.cr @@ -0,0 +1,66 @@ +require "../../pg" + +module PG + module Extension + # This extension adds support for decoding the Postgres `citext` type. + class CIText + include Extension + + def load(connection) + oid = connection.query_one "SELECT oid FROM pg_type WHERE typname = 'citext'", as: UInt32 + connection.register_decoder Decoder.new([oid.to_i]) + end + + struct Decoder + include Decoders::Decoder + + getter oids : Array(Int32) + + def initialize(@oids : Array(Int32)) + end + + def decode(io, bytesize, oid) + PG::CIText.new(Decoders::StringDecoder.new.decode(io, bytesize, oid)) + end + + def type + PG::CIText + end + end + end + + Connection.register_extension CIText.new + end + + struct CIText + def initialize(text : String) + @text = text + end + + def hash(hasher) + @text.hash(hasher) + end + + def ==(other : self) + self == other.@text + end + + def ==(other : String) + @text.compare(other, case_insensitive: true) + end + + def to_s + @text + end + + def to_s(io) + @text.to_s io + end + end +end + +class String + def ==(other : PG::CIText) + other == self + end +end diff --git a/src/pg/result_set.cr b/src/pg/result_set.cr index e37a1514..8340b144 100644 --- a/src/pg/result_set.cr +++ b/src/pg/result_set.cr @@ -171,7 +171,7 @@ class PG::ResultSet < ::DB::ResultSet end private def decoder(index = @column_index) - Decoders.from_oid(oid(index)) + statement.connection.decoder(oid(index)) end private def oid(index = @column_index)