libpq/
lib.rs

1use std::{
2    ffi::{CString, NulError},
3    fmt::Display,
4    io::{Read, Seek},
5    ops::ControlFlow,
6    os::raw::{c_char, c_void},
7    ptr::null_mut,
8};
9
10use std::fmt::Debug;
11
12use tempfile::Builder;
13
14include!("bindings.rs");
15
/// The socket descriptor of a libpq connection, as returned by `PQsocket`.
///
/// It is a plain integer handle, so it is cheap to copy; it does not own the socket
/// (the connection does).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PgSocket {
    socket: i32,
}
19
/// Reason why a socket poll did not report the socket as ready.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PgSocketPollResult {
    /// The deadline passed before the socket became ready.
    Timeout,
    /// Polling failed; carries the OS error description.
    Error(String),
}

impl Display for PgSocketPollResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            PgSocketPollResult::Timeout => write!(f, "Timeout"),
            PgSocketPollResult::Error(s) => write!(f, "Error: {}", s),
        }
    }
}

/// Allows the poll outcome to be propagated with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for PgSocketPollResult {}
33
34impl PgSocket {
35    pub fn poll(
36        &self,
37        read: bool,
38        write: bool,
39        timeout: Option<f64>,
40    ) -> Result<(), PgSocketPollResult> {
41        unsafe {
42            let timeout_ms = match timeout {
43                Some(t) => PQgetCurrentTimeUSec() + (t * 1000000.0) as i64,
44                None => -1,
45            };
46
47            match PQsocketPoll(self.socket, read.into(), write.into(), timeout_ms) {
48                a if a > 0 => Ok(()),
49                0 => Err(PgSocketPollResult::Timeout),
50                _ => Err(PgSocketPollResult::Error(
51                    std::io::Error::last_os_error().to_string(),
52                )),
53            }
54        }
55    }
56}
57pub struct PgConn {
58    conn: *mut PGconn,
59}
60
61unsafe impl Send for PgConn {}
62
63unsafe impl Sync for PgConn {}
64
65pub struct PgResult {
66    res: *mut PGresult,
67}
68
69pub struct PgNotify {
70    notify: *mut PGnotify,
71}
72
73impl PgNotify {
74    pub fn relname(&self) -> String {
75        unsafe {
76            let s = (*self.notify).relname;
77            std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned()
78        }
79    }
80
81    pub fn be_pid(&self) -> i32 {
82        unsafe { (*self.notify).be_pid }
83    }
84
85    pub fn extra(&self) -> String {
86        unsafe {
87            let s = (*self.notify).extra;
88            std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned()
89        }
90    }
91}
92
93impl Drop for PgConn {
94    fn drop(&mut self) {
95        unsafe {
96            PQfinish(self.conn);
97        }
98    }
99}
100
101impl Drop for PgNotify {
102    fn drop(&mut self) {
103        unsafe {
104            PQfreemem(self.notify as *mut c_void);
105        }
106    }
107}
108
109impl Drop for PgResult {
110    fn drop(&mut self) {
111        unsafe {
112            PQclear(self.res);
113        }
114    }
115}
116
117impl PgConn {
118    /// Connect to the database using environment variables.
119    ///
120    /// See the [official doc](https://www.postgresql.org/docs/current/libpq-envars.html).
121    pub fn connect_db_env_vars() -> Result<PgConn, NulError> {
122        Self::connect_db("")
123    }
124
125    pub fn connect_db(s: &str) -> Result<PgConn, NulError> {
126        unsafe {
127            let conninfo = std::ffi::CString::new(s)?;
128            let conn = PQconnectdb(conninfo.as_ptr());
129            Ok(PgConn { conn })
130        }
131    }
132
133    pub fn status(&self) -> ConnStatusType {
134        unsafe { PQstatus(self.conn) }
135    }
136
137    pub fn exec(&self, query: &str) -> Result<PgResult, NulError> {
138        unsafe {
139            let c_query = std::ffi::CString::new(query)?;
140            let res = PQexec(self.conn, c_query.as_ptr());
141            Ok(PgResult { res })
142        }
143    }
144
145    pub fn exec_file(&self, file_path: &str) -> Result<PgResult, NulError> {
146        let content = std::fs::read_to_string(file_path).expect("Failed to read file.");
147        self.exec(&content)
148    }
149
150    pub fn trace(&mut self, file: &str) {
151        unsafe {
152            let c_file = std::ffi::CString::new(file).unwrap();
153            let mode = std::ffi::CString::new("w").unwrap();
154            let fp = fopen(c_file.as_ptr(), mode.as_ptr());
155            PQtrace(self.conn, fp);
156            assert_eq!(fflush(fp), 0);
157        }
158    }
159
160    pub fn untrace(&mut self) {
161        unsafe {
162            PQuntrace(self.conn);
163        }
164    }
165
166    pub fn socket(&self) -> PgSocket {
167        unsafe {
168            PgSocket {
169                socket: PQsocket(self.conn),
170            }
171        }
172    }
173
174    pub fn consume_input(&mut self) -> Result<(), String> {
175        unsafe {
176            if PQconsumeInput(self.conn) == 0 {
177                Err(self.error_message())
178            } else {
179                Ok(())
180            }
181        }
182    }
183
184    pub fn notifies(&mut self) -> Option<PgNotify> {
185        unsafe {
186            let notify = PQnotifies(self.conn);
187            if notify.is_null() {
188                None
189            } else {
190                Some(PgNotify { notify })
191            }
192        }
193    }
194
195    pub fn error_message(&self) -> String {
196        unsafe {
197            let s = PQerrorMessage(self.conn);
198            if s.is_null() {
199                "".to_string()
200            } else {
201                std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned()
202            }
203        }
204    }
205
206    ///
207    /// A callback function to receive notices from the server.
208    /// https://stackoverflow.com/questions/24191249/working-with-c-void-in-an-ffi
209    /// https://adventures.michaelfbryan.com/posts/rust-closures-in-ffi/
210    extern "C" fn ffi_notice_processor<F>(arg: *mut c_void, data: *const c_char)
211    where
212        F: FnMut(String),
213    {
214        unsafe {
215            let s = std::ffi::CStr::from_ptr(data)
216                .to_string_lossy()
217                .into_owned();
218
219            let f = &mut *(arg as *mut F);
220
221            f(s);
222        }
223    }
224
225    pub fn set_notice_processor<F>(&mut self, proc: F) -> Box<F>
226    where
227        F: FnMut(String),
228    {
229        unsafe {
230            let mut b = Box::new(proc);
231            let a = b.as_mut() as *mut F as *mut c_void;
232            PQsetNoticeProcessor(self.conn, Some(Self::ffi_notice_processor::<F>), a);
233            b
234        }
235    }
236
237    extern "C" fn ffi_notice_receiver<F>(arg: *mut c_void, data: *const PGresult)
238    where
239        F: FnMut(PgResult),
240    {
241        unsafe {
242            let s = PgResult {
243                res: data as *mut PGresult,
244            };
245
246            let f = &mut *(arg as *mut F);
247
248            f(s);
249        }
250    }
251
252    /// Sets a notice receiver function to receive notices from the server.
253    /// Notices are sent to the receiver after command execution is completed.
254    /// https://www.postgresql.org/docs/current/libpq-notice-processing.html
255    pub fn set_notice_receiver<F>(&mut self, proc: F) -> Box<F>
256    where
257        F: FnMut(PgResult),
258    {
259        unsafe {
260            let mut b = Box::new(proc);
261            let a = b.as_mut() as *mut F as *mut c_void;
262            PQsetNoticeReceiver(self.conn, Some(Self::ffi_notice_receiver::<F>), a);
263            b
264        }
265    }
266
267    pub fn listen<F, T>(&mut self, timeout_sec: Option<f64>, proc: F) -> Vec<T>
268    where
269        F: Fn(usize, PgNotify) -> ControlFlow<(), Option<T>>,
270    {
271        let mut recvs = Vec::new();
272
273        let mut count = 0;
274
275        loop {
276            match self.socket().poll(true, false, timeout_sec) {
277                Ok(()) => {
278                    self.consume_input().expect("Failed to consume input.");
279
280                    while let Some(notify) = self.notifies() {
281                        match proc(count, notify) {
282                            ControlFlow::Continue(Some(p)) => recvs.push(p),
283                            ControlFlow::Break(()) => {
284                                break;
285                            }
286                            _ => {} // Do nothing
287                        }
288                        self.consume_input().expect("Failed to consume input.");
289                        count += 1;
290                    }
291                }
292                Err(_e) => break,
293            }
294        }
295
296        recvs
297    }
298}
299
300impl PgResult {
301    pub fn status(&self) -> ExecStatusType {
302        unsafe { PQresultStatus(self.res) }
303    }
304
305    pub fn cmd_status(&mut self) -> String {
306        unsafe {
307            let s = PQcmdStatus(self.res);
308            std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned()
309        }
310    }
311
312    pub fn error_message(&self) -> String {
313        unsafe {
314            let s = PQresultErrorMessage(self.res);
315            std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned()
316        }
317    }
318
319    pub fn error_field(&self, field_code: u8) -> Option<String> {
320        unsafe {
321            let s = PQresultErrorField(self.res, field_code.into());
322            if s.is_null() {
323                None
324            } else {
325                Some(std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned())
326            }
327        }
328    }
329
330    /// Print the result to a file.
331    /// See the [official doc](https://www.postgresql.org/docs/current/libpq-exec.html#LIBPQ-PQPRINT
332    pub fn print(
333        &self,
334        filename: &str,
335        header: bool,
336        align: bool,
337        fieldsep: &str,
338        standard: bool,
339        html3: bool,
340        expanded: bool,
341        pager: bool,
342    ) {
343        unsafe {
344            let sep = CString::new(fieldsep).unwrap();
345
346            let printopt = PQprintOpt {
347                header: header.into(),
348                align: align.into(),
349                fieldSep: sep.as_ptr() as *mut c_char,
350                tableOpt: null_mut(),
351                caption: null_mut(),
352                standard: standard.into(),
353                html3: html3.into(),
354                expanded: expanded.into(),
355                pager: pager.into(),
356                fieldName: null_mut(),
357            };
358
359            let fp = fopen(
360                CString::new(filename).unwrap().as_ptr(),
361                CString::new("w").unwrap().as_ptr(),
362            );
363
364            PQprint(fp, self.res, &printopt);
365
366            assert_eq!(fflush(fp), 0);
367            assert_eq!(fclose(fp), 0);
368        }
369    }
370
371    /// Get the value at the specified row and column.
372    /// See also [PQgetvalue](https://www.postgresql.org/docs/current/libpq-exec.html#LIBPQ-PQGETVALUE).
373    pub fn get_value(&self, row: i32, col: i32) -> String {
374        unsafe {
375            let s = PQgetvalue(self.res, row, col);
376            if s.is_null() {
377                "".to_string()
378            } else {
379                std::ffi::CStr::from_ptr(s).to_string_lossy().into_owned()
380            }
381        }
382    }
383}
384
385impl Display for PgResult {
386    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387        let mut temp_file = Builder::new()
388            .prefix("pg-res-")
389            .suffix(".json")
390            .tempfile()
391            .unwrap();
392
393        let temp_path = temp_file.path().to_path_buf();
394
395        self.print(
396            temp_path.as_path().to_str().unwrap(),
397            true,
398            true,
399            "|",
400            true,
401            false,
402            false,
403            false,
404        );
405
406        let mut s = String::new();
407        temp_file
408            .seek(std::io::SeekFrom::Start(0))
409            .expect("Failed to seek to start of temp file.");
410        temp_file
411            .read_to_string(&mut s)
412            .expect("Failed to read temp file.");
413
414        write!(f, "{}", s)
415    }
416}