[pbs-devel] [PATCH v2 proxmox-backup 09/20] server/rest: add ApiAuth trait to make user auth generic

Stefan Reiter s.reiter at proxmox.com
Wed Mar 24 16:18:16 CET 2021


This allows switching the base user identification/authentication method
in the REST server. It will initially be used for single-file restore VMs,
where authentication is based on a ticket file rather than the PBS user
backend (PAM/local).

Signed-off-by: Stefan Reiter <s.reiter at proxmox.com>
---
 src/bin/proxmox-backup-api.rs   |  13 ++-
 src/bin/proxmox-backup-proxy.rs |   7 +-
 src/server/auth.rs              | 160 ++++++++++++++++++--------------
 src/server/config.rs            |  17 +++-
 src/server/rest.rs              |  50 +++++-----
 5 files changed, 144 insertions(+), 103 deletions(-)

diff --git a/src/bin/proxmox-backup-api.rs b/src/bin/proxmox-backup-api.rs
index 7d800259..e514a801 100644
--- a/src/bin/proxmox-backup-api.rs
+++ b/src/bin/proxmox-backup-api.rs
@@ -6,8 +6,11 @@ use proxmox::api::RpcEnvironmentType;
 
 //use proxmox_backup::tools;
 //use proxmox_backup::api_schema::config::*;
-use proxmox_backup::server::rest::*;
-use proxmox_backup::server;
+use proxmox_backup::server::{
+    self,
+    auth::default_api_auth,
+    rest::*,
+};
 use proxmox_backup::tools::daemon;
 use proxmox_backup::auth_helpers::*;
 use proxmox_backup::config;
@@ -53,7 +56,11 @@ async fn run() -> Result<(), Error> {
     let _ = csrf_secret(); // load with lazy_static
 
     let mut config = server::ApiConfig::new(
-        buildcfg::JS_DIR, &proxmox_backup::api2::ROUTER, RpcEnvironmentType::PRIVILEGED)?;
+        buildcfg::JS_DIR,
+        &proxmox_backup::api2::ROUTER,
+        RpcEnvironmentType::PRIVILEGED,
+        default_api_auth(),
+    )?;
 
     let mut commando_sock = server::CommandoSocket::new(server::our_ctrl_sock());
 
diff --git a/src/bin/proxmox-backup-proxy.rs b/src/bin/proxmox-backup-proxy.rs
index 541d34b5..7e026455 100644
--- a/src/bin/proxmox-backup-proxy.rs
+++ b/src/bin/proxmox-backup-proxy.rs
@@ -14,6 +14,7 @@ use proxmox::api::RpcEnvironmentType;
 use proxmox_backup::{
     backup::DataStore,
     server::{
+        auth::default_api_auth,
         WorkerTask,
         ApiConfig,
         rest::*,
@@ -84,7 +85,11 @@ async fn run() -> Result<(), Error> {
     let _ = csrf_secret(); // load with lazy_static
 
     let mut config = ApiConfig::new(
-        buildcfg::JS_DIR, &proxmox_backup::api2::ROUTER, RpcEnvironmentType::PUBLIC)?;
+        buildcfg::JS_DIR,
+        &proxmox_backup::api2::ROUTER,
+        RpcEnvironmentType::PUBLIC,
+        default_api_auth(),
+    )?;
 
     // Enable experimental tape UI if tape.cfg exists
     if Path::new("/etc/proxmox-backup/tape.cfg").exists() {
diff --git a/src/server/auth.rs b/src/server/auth.rs
index 24151886..7239535d 100644
--- a/src/server/auth.rs
+++ b/src/server/auth.rs
@@ -1,6 +1,8 @@
 //! Provides authentication primitives for the HTTP server
 use anyhow::{bail, format_err, Error};
 
+use std::sync::Arc;
+
 use crate::tools::ticket::Ticket;
 use crate::auth_helpers::*;
 use crate::tools;
@@ -10,6 +12,17 @@ use crate::api2::types::{Authid, Userid};
 use hyper::header;
 use percent_encoding::percent_decode_str;
 
+pub trait ApiAuth {
+    type AuthData;
+    fn extract_auth_data(&self, headers: &http::HeaderMap) -> Option<Self::AuthData>;
+    fn check_auth(
+        &self,
+        method: &hyper::Method,
+        auth_data: &Self::AuthData,
+        user_info: &CachedUserInfo,
+    ) -> Result<Authid, Error>;
+}
+
 pub struct UserAuthData {
     ticket: String,
     csrf_token: Option<String>,
@@ -20,83 +33,92 @@ pub enum AuthData {
     ApiToken(String),
 }
 
-pub fn extract_auth_data(headers: &http::HeaderMap) -> Option<AuthData> {
-    if let Some(raw_cookie) = headers.get(header::COOKIE) {
-        if let Ok(cookie) = raw_cookie.to_str() {
-            if let Some(ticket) = tools::extract_cookie(cookie, "PBSAuthCookie") {
-                let csrf_token = match headers.get("CSRFPreventionToken").map(|v| v.to_str()) {
-                    Some(Ok(v)) => Some(v.to_owned()),
-                    _ => None,
-                };
-                return Some(AuthData::User(UserAuthData {
-                    ticket,
-                    csrf_token,
-                }));
-            }
-        }
-    }
-
-    match headers.get(header::AUTHORIZATION).map(|v| v.to_str()) {
-        Some(Ok(v)) => {
-            if v.starts_with("PBSAPIToken ") || v.starts_with("PBSAPIToken=") {
-                Some(AuthData::ApiToken(v["PBSAPIToken ".len()..].to_owned()))
-            } else {
-                None
-            }
-        },
-        _ => None,
-    }
+pub struct UserApiAuth {}
+pub fn default_api_auth() -> Arc<UserApiAuth> {
+    Arc::new(UserApiAuth {})
 }
 
-pub fn check_auth(
-    method: &hyper::Method,
-    auth_data: &AuthData,
-    user_info: &CachedUserInfo,
-) -> Result<Authid, Error> {
-    match auth_data {
-        AuthData::User(user_auth_data) => {
-            let ticket = user_auth_data.ticket.clone();
-            let ticket_lifetime = tools::ticket::TICKET_LIFETIME;
+impl ApiAuth for UserApiAuth {
+    type AuthData = AuthData;
 
-            let userid: Userid = Ticket::<super::ticket::ApiTicket>::parse(&ticket)?
-                .verify_with_time_frame(public_auth_key(), "PBS", None, -300..ticket_lifetime)?
-                .require_full()?;
-
-            let auth_id = Authid::from(userid.clone());
-            if !user_info.is_active_auth_id(&auth_id) {
-                bail!("user account disabled or expired.");
-            }
-
-            if method != hyper::Method::GET {
-                if let Some(csrf_token) = &user_auth_data.csrf_token {
-                    verify_csrf_prevention_token(csrf_secret(), &userid, &csrf_token, -300, ticket_lifetime)?;
-                } else {
-                    bail!("missing CSRF prevention token");
+    fn extract_auth_data(&self, headers: &http::HeaderMap) -> Option<Self::AuthData> {
+        if let Some(raw_cookie) = headers.get(header::COOKIE) {
+            if let Ok(cookie) = raw_cookie.to_str() {
+                if let Some(ticket) = tools::extract_cookie(cookie, "PBSAuthCookie") {
+                    let csrf_token = match headers.get("CSRFPreventionToken").map(|v| v.to_str()) {
+                        Some(Ok(v)) => Some(v.to_owned()),
+                        _ => None,
+                    };
+                    return Some(AuthData::User(UserAuthData {
+                        ticket,
+                        csrf_token,
+                    }));
                 }
             }
+        }
 
-            Ok(auth_id)
-        },
-        AuthData::ApiToken(api_token) => {
-            let mut parts = api_token.splitn(2, ':');
-            let tokenid = parts.next()
-                .ok_or_else(|| format_err!("failed to split API token header"))?;
-            let tokenid: Authid = tokenid.parse()?;
+        match headers.get(header::AUTHORIZATION).map(|v| v.to_str()) {
+            Some(Ok(v)) => {
+                if v.starts_with("PBSAPIToken ") || v.starts_with("PBSAPIToken=") {
+                    Some(AuthData::ApiToken(v["PBSAPIToken ".len()..].to_owned()))
+                } else {
+                    None
+                }
+            },
+            _ => None,
+        }
+    }
 
-            if !user_info.is_active_auth_id(&tokenid) {
-                bail!("user account or token disabled or expired.");
+    fn check_auth(
+        &self,
+        method: &hyper::Method,
+        auth_data: &Self::AuthData,
+        user_info: &CachedUserInfo,
+    ) -> Result<Authid, Error> {
+        match auth_data {
+            AuthData::User(user_auth_data) => {
+                let ticket = user_auth_data.ticket.clone();
+                let ticket_lifetime = tools::ticket::TICKET_LIFETIME;
+
+                let userid: Userid = Ticket::<super::ticket::ApiTicket>::parse(&ticket)?
+                    .verify_with_time_frame(public_auth_key(), "PBS", None, -300..ticket_lifetime)?
+                    .require_full()?;
+
+                let auth_id = Authid::from(userid.clone());
+                if !user_info.is_active_auth_id(&auth_id) {
+                    bail!("user account disabled or expired.");
+                }
+
+                if method != hyper::Method::GET {
+                    if let Some(csrf_token) = &user_auth_data.csrf_token {
+                        verify_csrf_prevention_token(csrf_secret(), &userid, &csrf_token, -300, ticket_lifetime)?;
+                    } else {
+                        bail!("missing CSRF prevention token");
+                    }
+                }
+
+                Ok(auth_id)
+            },
+            AuthData::ApiToken(api_token) => {
+                let mut parts = api_token.splitn(2, ':');
+                let tokenid = parts.next()
+                    .ok_or_else(|| format_err!("failed to split API token header"))?;
+                let tokenid: Authid = tokenid.parse()?;
+
+                if !user_info.is_active_auth_id(&tokenid) {
+                    bail!("user account or token disabled or expired.");
+                }
+
+                let tokensecret = parts.next()
+                    .ok_or_else(|| format_err!("failed to split API token header"))?;
+                let tokensecret = percent_decode_str(tokensecret)
+                    .decode_utf8()
+                    .map_err(|_| format_err!("failed to decode API token header"))?;
+
+                crate::config::token_shadow::verify_secret(&tokenid, &tokensecret)?;
+
+                Ok(tokenid)
             }
-
-            let tokensecret = parts.next()
-                .ok_or_else(|| format_err!("failed to split API token header"))?;
-            let tokensecret = percent_decode_str(tokensecret)
-                .decode_utf8()
-                .map_err(|_| format_err!("failed to decode API token header"))?;
-
-            crate::config::token_shadow::verify_secret(&tokenid, &tokensecret)?;
-
-            Ok(tokenid)
         }
     }
 }
-
diff --git a/src/server/config.rs b/src/server/config.rs
index 9094fa80..50ee5b85 100644
--- a/src/server/config.rs
+++ b/src/server/config.rs
@@ -13,8 +13,9 @@ use proxmox::api::{ApiMethod, Router, RpcEnvironmentType};
 use proxmox::tools::fs::{create_path, CreateOptions};
 
 use crate::tools::{FileLogger, FileLogOptions};
+use super::auth::ApiAuth;
 
-pub struct ApiConfig {
+pub struct ApiConfig<A: 'static + Send + Sync> {
     basedir: PathBuf,
     router: &'static Router,
     aliases: HashMap<String, PathBuf>,
@@ -23,11 +24,16 @@ pub struct ApiConfig {
     template_files: RwLock<HashMap<String, (SystemTime, PathBuf)>>,
     request_log: Option<Arc<Mutex<FileLogger>>>,
     pub enable_tape_ui: bool,
+    pub api_auth: Arc<dyn ApiAuth<AuthData = A> + Send + Sync>,
 }
 
-impl ApiConfig {
-
-    pub fn new<B: Into<PathBuf>>(basedir: B, router: &'static Router, env_type: RpcEnvironmentType) -> Result<Self, Error> {
+impl<A: 'static + Send + Sync> ApiConfig<A> {
+    pub fn new<B: Into<PathBuf>>(
+        basedir: B,
+        router: &'static Router,
+        env_type: RpcEnvironmentType,
+        api_auth: Arc<dyn ApiAuth<AuthData = A> + Send + Sync>,
+    ) -> Result<Self, Error> {
         Ok(Self {
             basedir: basedir.into(),
             router,
@@ -37,7 +43,8 @@ impl ApiConfig {
             template_files: RwLock::new(HashMap::new()),
             request_log: None,
             enable_tape_ui: false,
-       })
+            api_auth,
+        })
     }
 
     pub fn find_method(
diff --git a/src/server/rest.rs b/src/server/rest.rs
index 2169766a..5f355c1d 100644
--- a/src/server/rest.rs
+++ b/src/server/rest.rs
@@ -41,7 +41,6 @@ use proxmox::api::schema::{
 use super::environment::RestEnvironment;
 use super::formatter::*;
 use super::ApiConfig;
-use super::auth::{check_auth, extract_auth_data};
 
 use crate::auth_helpers::*;
 use crate::api2::types::{Authid, Userid};
@@ -51,23 +50,23 @@ use crate::config::cached_user_info::CachedUserInfo;
 
 extern "C"  { fn tzset(); }
 
-pub struct RestServer {
-    pub api_config: Arc<ApiConfig>,
+pub struct RestServer<A: 'static + Send + Sync> {
+    pub api_config: Arc<ApiConfig<A>>,
 }
 
 const MAX_URI_QUERY_LENGTH: usize = 3072;
 
-impl RestServer {
+impl<A: Send + Sync + 'static> RestServer<A> {
 
-    pub fn new(api_config: ApiConfig) -> Self {
+    pub fn new(api_config: ApiConfig<A>) -> Self {
         Self { api_config: Arc::new(api_config) }
     }
 }
 
-impl tower_service::Service<&Pin<Box<tokio_openssl::SslStream<tokio::net::TcpStream>>>> for RestServer {
-    type Response = ApiService;
+impl<A: Send + Sync + 'static> tower_service::Service<&Pin<Box<tokio_openssl::SslStream<tokio::net::TcpStream>>>> for RestServer<A> {
+    type Response = ApiService<A>;
     type Error = Error;
-    type Future = Pin<Box<dyn Future<Output = Result<ApiService, Error>> + Send>>;
+    type Future = Pin<Box<dyn Future<Output = Result<ApiService<A>, Error>> + Send>>;
 
     fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
         Poll::Ready(Ok(()))
@@ -85,10 +84,10 @@ impl tower_service::Service<&Pin<Box<tokio_openssl::SslStream<tokio::net::TcpStr
     }
 }
 
-impl tower_service::Service<&tokio::net::TcpStream> for RestServer {
-    type Response = ApiService;
+impl<A: Send + Sync + 'static> tower_service::Service<&tokio::net::TcpStream> for RestServer<A> {
+    type Response = ApiService<A>;
     type Error = Error;
-    type Future = Pin<Box<dyn Future<Output = Result<ApiService, Error>> + Send>>;
+    type Future = Pin<Box<dyn Future<Output = Result<ApiService<A>, Error>> + Send>>;
 
     fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
         Poll::Ready(Ok(()))
@@ -106,10 +105,10 @@ impl tower_service::Service<&tokio::net::TcpStream> for RestServer {
     }
 }
 
-impl tower_service::Service<&tokio::net::UnixStream> for RestServer {
-    type Response = ApiService;
+impl<A: Send + Sync + 'static> tower_service::Service<&tokio::net::UnixStream> for RestServer<A> {
+    type Response = ApiService<A>;
     type Error = Error;
-    type Future = Pin<Box<dyn Future<Output = Result<ApiService, Error>> + Send>>;
+    type Future = Pin<Box<dyn Future<Output = Result<ApiService<A>, Error>> + Send>>;
 
     fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
         Poll::Ready(Ok(()))
@@ -126,9 +125,9 @@ impl tower_service::Service<&tokio::net::UnixStream> for RestServer {
     }
 }
 
-pub struct ApiService {
+pub struct ApiService<A: 'static + Send + Sync> {
     pub peer: std::net::SocketAddr,
-    pub api_config: Arc<ApiConfig>,
+    pub api_config: Arc<ApiConfig<A>>,
 }
 
 fn log_response(
@@ -214,7 +213,7 @@ fn get_user_agent(headers: &HeaderMap) -> Option<String> {
     }).ok()
 }
 
-impl tower_service::Service<Request<Body>> for ApiService {
+impl<A: 'static + Send + Sync> tower_service::Service<Request<Body>> for ApiService<A> {
     type Response = Response<Body>;
     type Error = Error;
     #[allow(clippy::type_complexity)]
@@ -417,11 +416,11 @@ pub async fn handle_api_request<Env: RpcEnvironment, S: 'static + BuildHasher +
     Ok(resp)
 }
 
-fn get_index(
+fn get_index<A: 'static + Send + Sync>(
     userid: Option<Userid>,
     csrf_token: Option<String>,
     language: Option<String>,
-    api: &Arc<ApiConfig>,
+    api: &Arc<ApiConfig<A>>,
     parts: Parts,
 ) ->  Response<Body> {
 
@@ -573,8 +572,8 @@ fn extract_lang_header(headers: &http::HeaderMap) -> Option<String> {
     None
 }
 
-async fn handle_request(
-    api: Arc<ApiConfig>,
+async fn handle_request<A: 'static + Send + Sync>(
+    api: Arc<ApiConfig<A>>,
     req: Request<Body>,
     peer: &std::net::SocketAddr,
 ) -> Result<Response<Body>, Error> {
@@ -599,6 +598,7 @@ async fn handle_request(
     rpcenv.set_client_ip(Some(*peer));
 
     let user_info = CachedUserInfo::new()?;
+    let auth = &api.api_auth;
 
     let delay_unauth_time = std::time::Instant::now() + std::time::Duration::from_millis(3000);
     let access_forbidden_time = std::time::Instant::now() + std::time::Duration::from_millis(500);
@@ -626,8 +626,8 @@ async fn handle_request(
             }
 
             if auth_required {
-                let auth_result = match extract_auth_data(&parts.headers) {
-                    Some(auth_data) => check_auth(&method, &auth_data, &user_info),
+                let auth_result = match auth.extract_auth_data(&parts.headers) {
+                    Some(auth_data) => auth.check_auth(&method, &auth_data, &user_info),
                     None => Err(format_err!("no authentication credentials provided.")),
                 };
                 match auth_result {
@@ -688,8 +688,8 @@ async fn handle_request(
 
         if comp_len == 0 {
             let language = extract_lang_header(&parts.headers);
-            if let Some(auth_data) = extract_auth_data(&parts.headers) {
-                match check_auth(&method, &auth_data, &user_info) {
+            if let Some(auth_data) = auth.extract_auth_data(&parts.headers) {
+                match auth.check_auth(&method, &auth_data, &user_info) {
                     Ok(auth_id) if !auth_id.is_token() => {
                         let userid = auth_id.user();
                         let new_csrf_token = assemble_csrf_prevention_token(csrf_secret(), userid);
-- 
2.20.1






More information about the pbs-devel mailing list