# coding=utf-8
# -*- coding: UTF-8 -*-
#
# Copyright 2015-2025, Vincenzo Arcidiacono;
# Licensed under the EUPL (the 'Licence');
# You may not use this work except in compliance with the Licence.
# You may obtain a copy of the Licence at: http://ec.europa.eu/idabc/eupl
"""
It provides functions to build the credit application services.
"""
import os
import re
import copy
import math
import json
import stripe
import datetime
import itertools
import schedula as sh
from .csrf import csrf
from .extensions import db
from .security import User
from . import json_secrets
from .security import is_admin
from .locale import lazy_gettext
from flask_security import current_user as cu, auth_required
from flask import jsonify, flash, Blueprint, abort
from sherlock import Lock
from sqlalchemy import (
Column, String, Integer, DateTime, JSON, or_, event, desc, asc
)
from dateutil.relativedelta import relativedelta
from dateutil.rrule import (
rrule, YEARLY, MONTHLY, WEEKLY, DAILY, HOURLY, MINUTELY, SECONDLY
)
from flask_caching import Cache
# Frequency codes accepted by `date_range`, mapped to `dateutil.rrule`
# frequency constants.
FREQUENCIES = {
    'M': MONTHLY, 'W': WEEKLY, 'D': DAILY, 'Y': YEARLY, 'h': HOURLY,
    'm': MINUTELY, 's': SECONDLY
}
# Frequency strings look like `<interval><code>` (e.g. `3M`, `D`, `12h`): the
# optional interval is a positive integer without leading zeros.
# Raw string: `\d` in a plain literal is an invalid escape sequence
# (SyntaxWarning since Python 3.12).
_re_freq = re.compile(r'^(?P<interval>[1-9]\d*)?(?P<freq>[MWDYhms])$')
[docs]
def date_range(start_time, end_time, freq):
    """
    Split the span from `start_time` to `end_time` into consecutive periods.

    Period boundaries are `start_time + k * interval` units. Month/year
    offsets are clamped to the end of the month (Jan 31 -> Feb 28 -> Mar 31).
    `dateutil.rrule` instead skips invalid dates, which used to drop whole
    periods for subscriptions starting on the 29th-31st.

    :param start_time:
        Start of the first period (microseconds are dropped, as `rrule` did).
    :type start_time: datetime.datetime

    :param end_time:
        Inclusive upper bound for period boundaries; a trailing partial
        period is not returned.
    :type end_time: datetime.datetime

    :param freq:
        `<interval><code>` with code in `Y`, `M`, `W`, `D`, `h`, `m`, `s`
        (e.g. `'M'`, `'2W'`, `'12h'`).
    :type freq: str

    :return:
        Iterator of `(period_start, period_end)` pairs.
    :rtype: collections.abc.Iterator

    :raises ValueError:
        If `freq` is not a valid frequency string.
    """
    match = _re_freq.match(freq)
    if match is None:
        raise ValueError(f'Invalid frequency {freq!r}.')
    d = match.groupdict()
    unit = {
        'Y': 'years', 'M': 'months', 'W': 'weeks', 'D': 'days',
        'h': 'hours', 'm': 'minutes', 's': 'seconds'
    }[d['freq']]
    interval = int(d['interval'] or '1')
    start_time = start_time.replace(microsecond=0)

    def _boundaries():
        # Always offset from `start_time` (not from the previous boundary),
        # so a clamped month-end does not drift to earlier days.
        step = 0
        while True:
            boundary = start_time + relativedelta(**{unit: step * interval})
            if boundary > end_time:
                return
            yield boundary
            step += 1

    return itertools.pairwise(_boundaries())
# Blueprint holding the credits/Stripe endpoints.
bp = Blueprint('schedula_credits', __name__)

# Expiry used by `Wallet.balance` for credits without an `expired_at`.
max_date = datetime.datetime(
    year=9999, month=12, day=21, hour=23, minute=59, second=59
)

# Association table: users that can access a wallet besides its owner.
users_wallet = db.Table(
    'users_wallet',
    db.Model.metadata,
    Column('user_id', Integer, db.ForeignKey('user.id'), primary_key=True),
    Column(
        'wallet_id', Integer, db.ForeignKey('wallet.id'), primary_key=True
    ),
)
[docs]
class Wallet(db.Model):
    """
    Credit wallet owned by a user (`user`) and accessible by other users
    (`users`, through the `users_wallet` association table).

    Balances are not stored: they are recomputed from the wallet's `Txn` rows.
    """
    __tablename__ = 'wallet'
    id = Column(Integer, primary_key=True)
    # Owner of the wallet (at most one owned wallet per user).
    user_id = Column(Integer, db.ForeignKey('user.id'), unique=True)
    user = db.relationship('User', foreign_keys=[user_id])
    users = db.relationship('User', secondary=users_wallet)

    def __repr__(self):
        return f'Wallet({self.id}) {self.user.name}'

    def name(self):
        """Return the owner's full name (missing parts become empty)."""
        return f"{self.user.firstname or ''} {self.user.lastname or ''}"

    def lock(self):
        """Return the `sherlock` lock dedicated to this wallet."""
        return Lock(f'wallet-{self.id}')

    def subscription(self, day=None, session=db.session):
        """
        Return the active Stripe subscriptions of the wallet owner.

        `day` and `session` are unused (kept for interface compatibility).

        :return:
            `{subscription_id: {key: metadata}}`, where keys are the product
            ids, entitlement-feature lookup keys and price ids of the
            subscription items. Entries with empty metadata are dropped.
        :rtype: dict
        """
        from flask import current_app as ca
        api_key = ca.config['STRIPE_SECRET_KEY']
        subscriptions = {}
        products = {}  # Features already fetched, by product id.
        # The wallet owner's customer, not the current user's (they differ
        # when an admin inspects another user's wallet).
        customer = user2stripe_customer(self.user)
        for subscription in stripe.Subscription.list(
                customer=customer, status='active',
                expand=['data.items.data.price'], api_key=api_key
        ).auto_paging_iter():
            subs = {}
            for item in subscription.get('items').data:
                product_id = item.price.product
                if product_id in products:
                    features = products[product_id]
                else:
                    product = stripe.Product.retrieve(
                        product_id, api_key=api_key
                    )
                    products[product_id] = features = {
                        product_id: dict(product.metadata),
                    }
                    for v in stripe.Product.list_features(
                            product_id, api_key=api_key
                    ).data:
                        feat = v.entitlement_feature
                        features[feat.lookup_key] = dict(feat.metadata or {})
                subs.update(features)
                subs[item.price.id] = dict(item.price.metadata)
            subscriptions[subscription.id] = {
                k: v for k, v in subs.items() if v
            }
        return subscriptions

    def balance(self, product=None, day=None, session=db.session):
        """
        Compute the credits available at `day`.

        Non-zero transactions valid up to `day` are replayed in `valid_from`
        order, per product:

        - Positive credits fill a bucket keyed by their expiry (`max_date`
          when they never expire).
        - A debit drops the buckets expired before it and then consumes the
          bucket expiring first, spilling any shortfall onto the next
          buckets. The last bucket may go negative.

        Buckets expired by `day` are ignored in the result.

        :param product: Product name; if given only its balance is returned.
        :param day: Reference datetime (default: `datetime.datetime.today()`).
        :param session: SQLAlchemy session used for the query.
        :return: `{product: credits}`, or the credits of `product`.
        """
        day = datetime.datetime.today() if day is None else day
        balance = {}
        query = session.query(Txn).filter_by(
            wallet_id=self.id,
            **({} if product is None else {"product": product})
        ).filter(Txn.valid_from <= day).filter(Txn.credits != 0).order_by(
            asc(Txn.valid_from)
        )
        for r in query.all():
            bal = sh.get_nested_dicts(balance, r.product)
            if r.credits > 0:
                key = r.expired_at or max_date
            else:
                key = None
                while bal:
                    key = min(bal)
                    if key < r.valid_from:
                        bal.pop(key)
                    else:
                        break
                # NOTE(review): when every bucket had expired, the debit lands
                # on the last expired bucket and is therefore not counted
                # (unchanged behaviour).
                if key is None:
                    # First transaction of this product is a debit: the old
                    # code reused an unbound or foreign key here. Record it
                    # as a debt that never expires.
                    key = max_date
            new_bal = bal.get(key, 0) + r.credits
            while new_bal < 0:
                others = [k for k in bal if k != key]
                if not others:
                    # Nothing left to spill onto: keep the negative bucket
                    # (previously `min()` of an empty dict raised).
                    break
                bal.pop(key, None)
                key = min(others)
                new_bal += bal[key]
            bal[key] = new_bal
        balance = {k: sum([
            i for t, i in v.items() if t >= day
        ], 0) for k, v in balance.items()}
        if product is not None:
            balance = balance.get(product, 0)
        return balance

    def use(self, product, credits, session=db.session, created_by=None):
        """
        Consume `credits` of `product` and return the new transaction id.

        :param created_by: User id of the author (default: current user).
        :raises AssertionError: On negative credits or insufficient balance.
        """
        assert credits >= 0, 'Credits to be consumed have to be positive.'
        with self.lock():
            assert self.balance(
                product, session=session
            ) >= credits, 'Insufficient balance.'
            if created_by is None:
                created_by = cu.id
            # Consumption is a `USAGE` transaction (it was tagged `CHARGE`).
            t = Txn(
                wallet_id=self.id, type_id=USAGE, credits=-credits,
                product=product, created_by=created_by
            )
            session.add(t)
            session.commit()
        return t.id

    def charge(self, product, credits, session=db.session):
        """
        Add `credits` of `product` to the wallet and return the txn id.

        :raises AssertionError: On negative credits.
        """
        assert credits >= 0, 'Credits to be added have to be positive.'
        with self.lock():
            t = Txn(
                wallet_id=self.id, type_id=CHARGE, credits=credits,
                product=product
            )
            session.add(t)
            session.commit()
        return t.id

    def transfer_to(self, product, credits, to_wallet, session=db.session):
        """
        Move `credits` of `product` from this wallet to wallet `to_wallet`.

        :param to_wallet: Id of the destination wallet.
        :return: Ids of the debit and credit `TRANSFER` transactions.
        :raises AssertionError:
            On negative credits, unknown destination or insufficient balance.
        """
        from contextlib import ExitStack
        assert credits >= 0, 'Credits to be transfer have to be positive.'
        tran_from = Txn(
            wallet_id=self.id, type_id=TRANSFER, credits=-credits,
            product=product
        )
        tran_to = Txn(
            wallet_id=to_wallet, type_id=TRANSFER, credits=credits,
            product=product
        )
        to_wallet = session.get(Wallet, to_wallet)
        assert to_wallet, 'Destination wallet not found.'
        # Take each distinct wallet lock once, in id order: opposite
        # concurrent transfers cannot deadlock and a self-transfer does not
        # wait on its own lock.
        wallets = {w.id: w for w in (self, to_wallet)}
        with ExitStack() as stack:
            for wallet_id in sorted(wallets):
                stack.enter_context(wallets[wallet_id].lock())
            assert self.balance(
                product, session=session
            ) >= credits, 'Insufficient balance.'
            session.add_all([tran_from, tran_to])
            session.commit()
        return tran_from.id, tran_to.id
[docs]
@bp.route('/balance', methods=['GET'])
@bp.route('/balance/<int:wallet_id>', methods=['GET'])
@auth_required()
def get_balance(wallet_id=None):
    """
    Return the balances of the wallets accessible by a user.

    Query parameters:
    - `user_id`: the user (other users are admin-only; default current user).
    - `product`: restricts balances to one product.

    :return:
        JSON `{wallet_id: {'name': ..., 'balance': ..., 'main': bool}}`,
        where `main` flags the wallet owned by the user.
    """
    from flask import request
    # `type=int`: query args are strings and must compare equal to integer
    # ids (`cu.id`, `Wallet.user_id`). Unparsable values fall back to the
    # current user.
    user_id = request.args.get('user_id', cu.id, type=int)
    if not is_admin() and cu.id != user_id:
        abort(403)
    get_wallet(user_id)  # Make sure the user's own wallet exists.
    query = Wallet.query.filter(or_(
        Wallet.users.any(id=user_id), Wallet.user_id == user_id
    ))
    if wallet_id is not None:
        # `Wallet` has no `wallet_id` column: filter on its primary key.
        query = query.filter(Wallet.id == wallet_id)
    product = request.args.get('product')
    return jsonify({
        wallet.id: {
            'name': wallet.name(),
            'balance': wallet.balance(product),
            'main': wallet.user_id == user_id
        } for wallet in query.all()
    })
[docs]
@bp.route('/subscription', methods=['GET'])
@bp.route('/subscription/<int:wallet_id>', methods=['GET'])
@auth_required()
def get_subscription(wallet_id=None):
    """
    Return the active subscriptions of the wallets owned by a user.

    The `user_id` query parameter selects the user; only admins may select
    another user (default: current user).

    :return: JSON `{wallet_id: Wallet.subscription()}`.
    """
    from flask import request
    # `type=int` so the id compares equal to `cu.id`.
    user_id = request.args.get('user_id', cu.id, type=int)
    if not is_admin() and cu.id != user_id:
        abort(403)
    kw = {'user_id': user_id}
    if wallet_id is not None:
        kw['id'] = wallet_id
    return jsonify({
        wallet.id: wallet.subscription()
        for wallet in Wallet.query.filter_by(**kw).all()
    })
[docs]
class TxnType(db.Model):
    """Transaction type lookup row (seeded by `insert_transaction_type`)."""
    __tablename__ = 'transaction_type'

    id = Column(Integer, primary_key=True)
    name = Column(String(255))

    def __repr__(self):
        return format(self.name)
[docs]
class Txn(db.Model):
    """
    Wallet ledger entry.

    `credits` is the signed credit movement replayed by `Wallet.balance`
    (from `valid_from`, bucketed by `expired_at`). The money columns and
    `raw_data` record the Stripe object identified by `stripe_id`.
    `created_by`/`updated_by` default to the current user's id when
    available.
    """
    __tablename__ = 'wallet_transaction'

    id = Column(Integer, primary_key=True)
    wallet_id = Column(Integer, db.ForeignKey('wallet.id'), nullable=False)
    wallet = db.relationship('Wallet', foreign_keys=[wallet_id])
    type_id = Column(
        Integer, db.ForeignKey('transaction_type.id'), nullable=False
    )
    type = db.relationship('TxnType', foreign_keys=[type_id])

    # Credit movement: positive adds credits, negative consumes them.
    credits = Column(Integer, default=0)
    product = Column(String(255))

    # Amounts copied from Stripe.
    discount = Column(Integer, default=0)
    subtotal = Column(Integer, default=0)
    tax = Column(Integer, default=0)
    total = Column(Integer, default=0)
    currency = Column(String(64))
    stripe_id = Column(String(255))
    raw_data = Column('raw_data', JSON)

    # Validity window and audit timestamps.
    expired_at = Column(DateTime())
    valid_from = Column(
        DateTime(), nullable=False, default=datetime.datetime.utcnow
    )
    created_at = Column(
        DateTime(), nullable=False, default=datetime.datetime.utcnow
    )
    updated_at = Column(
        DateTime(), nullable=True, onupdate=datetime.datetime.utcnow
    )
    created_by = Column(
        Integer, db.ForeignKey('user.id'), nullable=True,
        default=lambda: getattr(cu, 'id', None)
    )
    updated_by = Column(
        Integer, db.ForeignKey('user.id'), nullable=True,
        onupdate=lambda: getattr(cu, 'id', None)
    )

    def __repr__(self):
        return 'Transaction - {}'.format(self.id)

    def update_credits(self, credits, session=db.session, force=False):
        """
        Make this (debit) transaction consume `credits` credits.

        Without `force`, the new amount must not exceed the consumed one.
        The change is flushed, not committed.

        :raises AssertionError: If the constraints above are violated.
        """
        assert credits >= 0, 'Credits update have to be positive.'
        if not force:
            assert -self.credits >= credits, \
                'Credits update have to be lower than previous.'
        self.credits = -credits
        session.add(self)
        session.flush()
# Far-future datetime constant.
INF_DATE = datetime.datetime(year=9999, month=12, day=31, hour=23, minute=59)

# Primary keys of the `TxnType` rows seeded by `insert_transaction_type`.
PURCHASE, REFUND, USAGE, CHARGE, TRANSFER, SUBSCRIPTION = range(1, 7)
[docs]
def insert_transaction_type(target, connection, **kw):
    """
    Seed the transaction types right after `transaction_type` is created.

    Registered below as an SQLAlchemy `after_create` listener.
    """
    names = (
        (PURCHASE, 'Purchase'),
        (REFUND, 'Refund'),
        (USAGE, 'Usage'),
        (CHARGE, 'Charge'),
        (TRANSFER, 'Transfer'),
        (SUBSCRIPTION, 'Subscription'),
    )
    rows = [{'id': type_id, 'name': name} for type_id, name in names]
    connection.execute(target.insert(), rows)


event.listen(TxnType.__table__, 'after_create', insert_transaction_type)
[docs]
def compute_line_items(quantity, tiers, type='graduated', extra=None):
    """
    Expand a tiered price into Stripe checkout line items.

    :param quantity:
        Number of units (credits) purchased.
    :type quantity: int

    :param tiers:
        Tiers with an optional inclusive upper bound `last_unit` and optional
        `per_unit` / `flat_fee` line-item templates. The tier with the
        highest `last_unit` is treated as unbounded.
    :type tiers: list[dict]

    :param type:
        How the tiers are applied:
        - `'volume'`: the whole quantity is billed with the tier containing it.
        - Anything else (default `'graduated'`): each tier bills the units
          falling inside it.
    :type type: str

    :param extra:
        Fields merged into every line item (item fields take priority).
    :type extra: dict | None

    :return:
        Line items, each with `metadata.credits` (credits it grants).
    :rtype: list[dict]
    """
    inf = float('inf')
    tiers = sorted(tiers, key=lambda x: x.get('last_unit', inf))
    # The highest tier is open-ended.
    tiers[-1] = {k: v for k, v in tiers[-1].items() if k != 'last_unit'}
    line_items = []
    if type == 'volume':
        # First tier whose cap contains `quantity` (the condition was
        # inverted). Always found: the last tier is unbounded.
        tier = next(
            tier for tier in tiers if quantity <= tier.get('last_unit', inf)
        )
        per_unit = tier.get('per_unit')
        if per_unit:
            line_items.append(sh.combine_nested_dicts(per_unit, {
                'quantity': quantity, 'metadata': {'credits': quantity}
            }))
        if tier.get('flat_fee'):
            # A flat fee is billed once, as in the graduated branch.
            line_items.append(sh.combine_nested_dicts(tier['flat_fee'], {
                'quantity': 1, 'metadata': {
                    'credits': 0 if per_unit else quantity
                }
            }))
    else:
        prev_unit = 0
        for tier in tiers:
            last_unit = tier.get('last_unit', inf)
            per_unit = tier.get('per_unit')
            credits = min(last_unit, quantity) - prev_unit
            if per_unit:
                line_items.append(sh.combine_nested_dicts(per_unit, {
                    'quantity': credits, 'metadata': {'credits': credits}
                }))
            if tier.get('flat_fee'):
                line_items.append(sh.combine_nested_dicts(tier['flat_fee'], {
                    'quantity': 1, 'metadata': {
                        'credits': 0 if per_unit else credits
                    }
                }))
            if quantity <= last_unit:
                break
            prev_unit = last_unit
    if extra:
        line_items = [sh.combine_dicts(extra, v) for v in line_items]
    return line_items
[docs]
def search_stripe_customer(api_key, user=cu):
    """
    Find the Stripe customer whose e-mail matches `user.email`.

    If found, its `metadata.user_id` is set to `user.id` when it differs.

    :return: The Stripe customer, or `None` if no customer matches.
    """
    # Single quotes must be escaped inside a quoted Stripe search value.
    email = str(user.email).replace("'", "\\'")
    for customer in stripe.Customer.search(
            query=f"email:'{email}'", api_key=api_key, limit=1
    ).data:
        # `.get`: customers created outside the app have no `user_id`.
        if customer.metadata.get('user_id') != str(user.id):
            customer = stripe.Customer.modify(
                customer.id, api_key=api_key,
                metadata={"user_id": str(user.id)}
            )
        return customer
    return None
[docs]
def user2stripe_customer(user=cu):
    """
    Return the Stripe customer id of `user`, creating the customer if needed.

    The id is cached for 60 seconds in the `schedula_cache` extension. The
    lookup/creation runs under a per-user `sherlock` lock.

    :param user: User object (default: current user).
    :return: Stripe customer id.
    :rtype: str
    """
    from flask import current_app as ca
    api_key = ca.config['STRIPE_SECRET_KEY']
    cache = ca.extensions['schedula_cache']
    key = f'Stripe-customer-{user.id}'
    with Lock(key, timeout=30):
        customer = cache.get(key)
        if customer:
            return customer
        customer = search_stripe_customer(api_key=api_key, user=user)
        if not customer or customer.get('deleted'):
            # Avoid a literal "None" for missing first/last names.
            name = f"{user.firstname or ''} {user.lastname or ''}".strip()
            customer = stripe.Customer.create(
                api_key=api_key, email=user.email, name=name,
                metadata={'user_id': str(user.id)}
            )
        customer = customer.id
        cache.set(key, customer, timeout=60)
    return customer
[docs]
def stripe_customer2user(customer):
    """
    Return the local `User` of a Stripe customer, creating it if needed.

    Lookup order:
    1. `customer.metadata.user_id`.
    2. `customer.email`.
    3. Otherwise a user is created and flushed, and its id is written back
       to the customer metadata.
    """
    # `.get`: customers created outside the app carry no `user_id`.
    user_id = customer.metadata.get('user_id')
    user = db.session.get(User, user_id) if user_id else None
    if user:
        return user
    user = User.query.filter_by(email=customer.email).first()
    if user:
        return user
    from flask import current_app as ca
    user = ca.security.datastore.create_user(
        email=customer.email, firstname=customer.name
    )
    db.session.flush([user])  # Assign `user.id` before linking it.
    api_key = ca.config['STRIPE_SECRET_KEY']
    stripe.Customer.modify(
        customer.id, api_key=api_key, metadata={"user_id": str(user.id)}
    )
    return user
[docs]
def get_wallet(user_id, session=db.session):
    """
    Return the wallet owned by `user_id`, creating (and committing) it first
    if it does not exist. The check-and-create runs under a per-user lock.
    """
    with Lock(f'wallet-user-{user_id}'):
        owned = session.query(Wallet).filter_by(
            user_id=user_id
        ).one_or_none()
        if owned is None:
            owned = Wallet(user_id=user_id)
            session.add(owned)
            session.commit()
    return owned
[docs]
@bp.route('/create-customer-pricing-table-session', methods=['POST'])
@auth_required()
def create_pricing_table():
    """
    Create a Stripe customer session enabling the embedded pricing table.

    The request body (JSON or form data) goes through
    `json_secrets.secrets(..., False)` and is merged over the defaults.
    The `customer` is always the current user's Stripe customer.

    :return: JSON `{"clientSecret": ...}`, or `{"error": ...}` on failure.
    """
    from flask import request, current_app as ca
    try:
        data = request.get_json() if request.is_json else dict(request.form)
        data = json_secrets.secrets(data, False)
        customer = user2stripe_customer()
        params = sh.combine_nested_dicts(data, base={
            "customer": customer,
            "components": {"pricing_table": {"enabled": True}}
        })
        # The body must not select another Stripe customer.
        params['customer'] = customer
        session = stripe.CustomerSession.create(
            api_key=ca.config['STRIPE_SECRET_KEY'], **params
        )
    except Exception as e:
        return jsonify(error=str(e))
    return jsonify(clientSecret=session.client_secret)
[docs]
@bp.route('/create-customer-portal-session', methods=['POST'])
@auth_required()
def create_portal(skip_data=False):
    """
    Create a Stripe billing-portal session for the current user.

    :param skip_data:
        Ignore the request body (used when called from `create_payment`).
    :return:
        JSON `{"session_url": ..., "subscription": ...}`, where
        `subscription` is the nickname (or id) of the plan of one active
        subscription, or `''`. On failure, JSON `{"error": ...}`.
    """
    from stripe.billing_portal import Session
    from flask import request, current_app as ca, session as flask_session
    try:
        if skip_data:
            data = {}
        else:
            data = request.get_json() if request.is_json else dict(request.form)
        data = json_secrets.secrets(data, False)
        customer = user2stripe_customer()
        api_key = ca.config['STRIPE_SECRET_KEY']
        subscription = ''
        for sub in stripe.Subscription.list(
                customer=customer, api_key=api_key, status="active", limit=1
        ).auto_paging_iter():
            plan = sub.get('items').data[0].plan
            subscription = plan.nickname or plan.id
            break
        params = sh.combine_nested_dicts(data, base={
            "customer": customer,
            'return_url': request.referrer,
            'locale': flask_session.get('locale', 'en_US').split('_')[0]
        })
        # The body must not open the portal of another Stripe customer.
        params['customer'] = customer
        portal = Session.create(api_key=api_key, **params)
    except Exception as e:
        return jsonify(error=str(e))
    return jsonify(session_url=portal.url, subscription=subscription)
[docs]
def get_discounts():
    """
    Collect the discounts granted by the current user's subscriptions.

    Subscription metadata entries named `discounts` are JSON lists of
    `[product_name, flat, percentage]`. Per product name, flat amounts add
    up and percentages compound.

    :return:
        `{}` when there is no discount, otherwise a dict with:
        - `discounts`: `{product_name: [flat, factor]}`.
        - `prod_name`: `{product_name: product_name}`.
        - `price`: `{price_id: product_name}` (active prices).
        - `product`: `{product_id: product_name}` (active products).
    """
    totals = {}
    subscriptions = get_wallet(cu.id).subscription()
    for keys, value in sh.stack_nested_keys(subscriptions):
        if keys[-1] != 'discounts':
            continue
        for name, flat, perc in json.loads(value):
            prev_flat, prev_factor = totals.get(name, (0, 1))
            totals[name] = prev_flat + flat, prev_factor * (1 - perc)
    discounts = {
        name: list(pair) for name, pair in totals.items() if pair != (0, 1)
    }
    if not discounts:
        return {}
    from flask import current_app as ca
    api_key = ca.config['STRIPE_SECRET_KEY']
    price_discounts, product_discounts = {}, {}
    for prod in stripe.Product.list(
            active=True, api_key=api_key
    ).auto_paging_iter():
        name = prod.name
        if name not in discounts:
            continue
        product_discounts[prod.id] = name
        for price in stripe.Price.list(
                active=True, product=prod.id, api_key=api_key
        ).auto_paging_iter():
            price_discounts[price.id] = name
    return {
        'discounts': discounts,
        'prod_name': {name: name for name in discounts},
        'price': price_discounts,
        'product': product_discounts
    }
[docs]
def update_line_items_discounts(line_items, discounts):
    """
    Apply subscription discounts to checkout line items.

    1. Items referencing a discounted `price` are converted to inline
       `price_data` so their amount can be changed.
    2. For each item of a discounted product, the flat discount is
       subtracted, never going below zero. The flat discount is a budget
       shared by the items of the same product.
    3. The percentage factor is then applied, rounding up.

    :param line_items: Stripe checkout line items (not modified).
    :param discounts: Non-empty result of `get_discounts` (not modified).
    :return: New list of line items.
    """
    from flask import current_app as ca
    api_key = ca.config['STRIPE_SECRET_KEY']
    line_items = copy.deepcopy(line_items)
    # The remaining flat budget is decremented below: work on a copy so the
    # caller's `discounts` is left untouched.
    remaining = copy.deepcopy(discounts['discounts'])
    for item in line_items:
        if 'price' in item and item['price'] in discounts['price']:
            p = stripe.Price.retrieve(item.pop('price'), api_key=api_key)
            item['price_data'] = {
                'currency': p.currency,
                # Not expanded here, so `product` is usually an id string.
                'product': getattr(p.product, 'id', p.product),
                'recurring': p.recurring,
                'tax_behavior': p.tax_behavior,
                'unit_amount_decimal': p.unit_amount_decimal
            }
        price_data = item.get('price_data')
        if price_data is None:
            continue
        if 'product' in price_data:
            name = discounts['product'].get(price_data['product'])
        else:
            name = discounts['prod_name'].get(
                price_data.get('product_data', {}).get('name')
            )
        if name is None:
            continue
        d = remaining[name]
        # NOTE(review): Stripe expresses both fields in cents, yet they get
        # different scales here — confirm the unit of the flat discount.
        for k, s in (('unit_amount', 1.0), ('unit_amount_decimal', 100.0)):
            if k not in price_data:
                continue
            quantity = item.get('quantity')
            if d[0] and quantity:  # No flat discount on zero/no quantity.
                cost = float(price_data[k]) / s * quantity
                new_cost = max(cost - d[0], 0)
                d[0] -= cost - new_cost
                amount = new_cost / quantity * s
            else:
                amount = float(price_data[k])
            price_data[k] = '%d' % math.ceil(amount * d[1])
    return line_items
[docs]
def get_tax_rates(tax_rates):
    """
    Resolve tax-rate specs to Stripe tax-rate ids.

    Strings are returned unchanged. A dict holds `TaxRate.create` parameters
    plus optional `metadata`:
    - An existing tax rate whose `metadata.hash` equals the SHA-256 of those
      parameters is reused.
    - Otherwise a new one is created, tagged with that hash.

    :param tax_rates: Tax-rate ids and/or parameter dicts.
    :return: Tax-rate ids, in input order.
    :rtype: list
    """
    res = []
    known = None  # Existing Stripe tax rates, fetched on first need.
    for spec in tax_rates:
        if isinstance(spec, dict):
            import hashlib
            from flask import current_app as ca
            api_key = ca.config['STRIPE_SECRET_KEY']
            metadata = spec.get('metadata', {})
            kwargs = {k: v for k, v in spec.items() if k != 'metadata'}
            digest = hashlib.sha256(json.dumps(
                kwargs, sort_keys=True
            ).encode()).hexdigest()
            if known is None:
                known = list(
                    stripe.TaxRate.list(api_key=api_key).auto_paging_iter()
                )
            match = next(
                (t for t in known if t.metadata.get('hash') == digest), None
            )
            if match is None:
                match = stripe.TaxRate.create(
                    api_key=api_key,
                    metadata=sh.combine_dicts(metadata, {'hash': digest}),
                    **kwargs
                )
                # Remember it: a repeated spec must not create a duplicate.
                known.append(match)
            spec = match.id
        res.append(spec)
    return res
[docs]
@bp.route('/create-checkout-session', methods=['POST'])
def create_payment():
    """
    Create an embedded Stripe Checkout session for the current user.

    The request body (JSON or form data) goes through
    `json_secrets.secrets(..., False)`. In `subscription` mode, a user with
    an active subscription gets a billing-portal session instead
    (`create_portal`). In `line_items`:
    - Tiered items (`tiers` + `quantity`) are expanded by
      `compute_line_items`.
    - Tax-rate specs are resolved by `get_tax_rates`.
    - Subscription discounts are applied.
    - The credits of each item are stored in `metadata.line_items`, read
      back by `checkout_session_completed`.

    :return:
        JSON `{"clientSecret": ..., "sessionId": ...}`, or `{"error": ...}`.
    """
    from stripe.checkout import Session
    from flask import request, current_app as ca, session as flask_session
    try:
        data = request.get_json() if request.is_json else dict(request.form)
        data = json_secrets.secrets(data, False)
        customer = user2stripe_customer()
        api_key = ca.config['STRIPE_SECRET_KEY']
        if data['mode'] == 'subscription':
            for _ in stripe.Subscription.list(
                    customer=customer, api_key=api_key, status='active',
                    limit=1
            ).auto_paging_iter():
                return create_portal(True)
        metadata = {f'customer_{k}': getattr(cu, k) for k in (
            'id', 'firstname', 'lastname', 'email'
        ) if hasattr(cu, k)}
        metadata.update({
            f'customer_{k}': json.dumps(getattr(cu, k))
            for k in ('custom_data',) if hasattr(cu, k)
        })
        if 'line_items' in data:
            items = data['line_items']
            if not isinstance(items, list):
                items = [items]
            line_items = []
            for d in items:
                for i in ('dynamic_tax_rates', 'tax_rates'):
                    if i in d:
                        d[i] = get_tax_rates(d[i])
                if 'tiers' in d:
                    line_items.extend(compute_line_items(
                        d.pop('quantity'), extra=d, **d.pop('tiers')
                    ))
                else:
                    line_items.append(d)
            # NOTE(review): this called `format_line_items`, which is not
            # defined in this module, so every request with line items
            # failed with a NameError. Applying the subscription discounts is
            # assumed to be its purpose — confirm.
            discounts = get_discounts()
            if discounts:
                line_items = update_line_items_discounts(line_items, discounts)
            metadata['line_items'] = json.dumps([
                d.pop('metadata', None) for d in line_items
            ])
            data['line_items'] = line_items
        params = sh.combine_nested_dicts(data, base={
            'ui_mode': 'embedded',
            'customer': customer,
            'customer_update': {"address": "auto"},
            'automatic_tax': {'enabled': True},
            'redirect_on_completion': 'never',
            'metadata': dict(metadata),  # Copy: the merge mutates `base`.
            'locale': flask_session.get('locale', 'en_US').split('_')[0]
        })
        # Server-computed values win over the request body. Credits are
        # granted from `metadata.line_items`, so a client-supplied value is
        # never trusted.
        params['customer'] = customer
        session_metadata = dict(params.get('metadata') or {})
        session_metadata.pop('line_items', None)
        session_metadata.update(metadata)
        params['metadata'] = session_metadata
        session = Session.create(api_key=api_key, **params)
    except Exception as e:
        return jsonify(error=str(e))
    return jsonify(clientSecret=session.client_secret, sessionId=session.id)
[docs]
def checkout_session_completed(session_id):
    """
    Record the purchases and credits of a completed Checkout session.

    Only `payment` mode sessions are processed. For each line item a
    `PURCHASE` and a `CHARGE` transaction are added to the customer's
    wallet. Credits come from the `line_items` metadata written by
    `create_payment`, falling back to the item quantity. Expiry comes from
    the `expires_at` metadata of the price (or product): an `asteval`
    expression where `now` is the session creation time and `relativedelta`
    is available.

    :return:
        `True` if transactions were recorded, `False` if the session was
        already processed, `None` for non-payment sessions.
    """
    from flask import current_app as ca
    from stripe.checkout import Session
    with Lock(f'Txn-stripe-{session_id}'):
        if db.session.query(
            Txn.query.filter_by(stripe_id=session_id).exists()
        ).scalar():
            return False
        session = Session.retrieve(
            session_id, api_key=ca.config['STRIPE_SECRET_KEY'],
            expand=['line_items.data.price.product', 'customer']
        )
        if session.mode != 'payment':
            return None
        # NOTE(review): `payment_status` is not checked. With delayed payment
        # methods a completed session may still be unpaid — confirm.
        customer = session.customer
        current_time = datetime.datetime.fromtimestamp(session.created)
        from asteval import Interpreter
        # `expires_at` expressions come from Stripe price/product metadata.
        aeval = Interpreter(usersyms={
            'now': current_time,
            'relativedelta': relativedelta
        }, minimal=True)
        user = stripe_customer2user(customer)
        wallet = get_wallet(user.id)
        line_items = json.loads(session.metadata.get('line_items', '[]'))
        transactions = []
        # NOTE(review): the expanded `line_items` list may be paginated by
        # Stripe — confirm sessions never exceed one page.
        for i, item in enumerate(session.line_items.data):
            price = item.price
            product = price.product
            expired_at = aeval(price.metadata.get(
                'expires_at', product.metadata.get('expires_at', 'None')
            ))
            try:
                credits = line_items[i]['credits']
            except (IndexError, KeyError, TypeError):
                # TypeError: items without metadata are stored as `null`.
                credits = item.quantity
            transactions.append(Txn(
                wallet_id=wallet.id,
                type_id=PURCHASE,
                product=product.name,
                subtotal=item.amount_subtotal,
                discount=item.amount_discount,
                tax=item.amount_tax,
                total=item.amount_total,
                currency=item.currency,
                stripe_id=session_id,
                raw_data=item.to_dict_recursive(),
                created_by=user.id,
                valid_from=current_time,
                expired_at=expired_at,
            ))
            transactions.append(Txn(
                wallet_id=wallet.id,
                type_id=CHARGE,
                credits=credits,
                product=product.name,
                stripe_id=session_id,
                created_by=user.id,
                valid_from=current_time,
                expired_at=expired_at,
            ))
        db.session.add_all(transactions)
        db.session.commit()
    return True
[docs]
def refund_charge(stripe_id, start_time, session, type_ids=(CHARGE,)):
    """
    Revoke, from `start_time` on, the credits granted by a Stripe object.

    Among the transactions with `stripe_id` and a type in `type_ids`:
    - those starting after `start_time` are deleted;
    - the others expire at `start_time` unless they already expired earlier.

    `session` is committed.
    """
    with Lock(f'Txn-stripe-{stripe_id}'):
        # Query through `session`, the same session that is committed.
        base = session.query(Txn).filter_by(stripe_id=stripe_id).filter(or_(*(
            Txn.type_id == type_id for type_id in type_ids
        )))
        base.filter(Txn.valid_from > start_time).delete(
            synchronize_session=False
        )
        base.filter(or_(
            Txn.expired_at == None,  # noqa: E711 (renders IS NULL)
            Txn.expired_at > start_time
        )).update({"expired_at": start_time})
        session.commit()
[docs]
def subscription_invoice_paid(event):
    """
    Handle a Stripe `invoice.paid` event for subscription invoices.

    Processed only for billing reasons `subscription_create`,
    `subscription_update` and `subscription_cycle`, once per invoice. The
    guard is a lock plus the existence of transactions with the invoice id.

    For every subscription item:
    - A `SUBSCRIPTION` transaction is recorded.
    - The credits listed in the `products` metadata are charged once per
      `freq` period of the billing period. The lists come from the product,
      the price and the product's entitlement features, as JSON lists of
      `[name, credits, freq]`.

    On `subscription_update`, the credits of the latest previous
    subscription invoice are revoked from the current period start.

    :param event: Stripe `invoice.paid` event.
    :return:
        `True` when transactions were recorded, `False` if the invoice was
        already processed, `None` for other billing reasons.
    """
    invoice = event.data.object
    billing_reason = invoice.billing_reason
    if billing_reason not in (
        'subscription_create', 'subscription_update', 'subscription_cycle'
    ):
        return
    from flask import current_app as ca
    with Lock(f'Txn-stripe-{invoice.id}'):
        # Idempotency: skip invoices that already produced transactions.
        if db.session.query(
            Txn.query.filter_by(stripe_id=invoice.id).exists()
        ).scalar():
            return False
        api_key = ca.config['STRIPE_SECRET_KEY']
        # NOTE(review): `invoice.subscription`, `current_period_*` and
        # `invoice.tax` were moved in recent Stripe API versions — confirm
        # against the pinned API version.
        subscription = stripe.Subscription.retrieve(
            invoice.subscription, api_key=api_key, expand=[
                'customer', 'items.data.price.product'
            ]
        )
        customer = subscription.customer
        user = stripe_customer2user(customer)
        wallet = get_wallet(user.id)
        start_time = datetime.datetime.fromtimestamp(
            subscription.current_period_start
        )
        # NOTE(review): the extra day presumably acts as a grace period
        # after the period end — confirm.
        end_time = datetime.datetime.fromtimestamp(
            subscription.current_period_end
        ) + relativedelta(days=1)
        if billing_reason == 'subscription_update':
            # Revoke what the latest earlier subscription invoice granted.
            latest_invoice = Txn.query.filter_by(
                wallet_id=wallet.id, type_id=SUBSCRIPTION
            ).filter(Txn.valid_from <= start_time).order_by(
                desc(Txn.valid_from)
            ).first()
            if latest_invoice:
                latest_invoice = latest_invoice.stripe_id
                refund_charge(latest_invoice, start_time, db.session, (
                    CHARGE, SUBSCRIPTION
                ))
        transactions = []
        for item in subscription.get('items').data:
            product = item.price.product
            if item.object == 'subscription_item':
                # NOTE(review): each item records the whole invoice amounts,
                # so totals repeat when a subscription has several items.
                transactions.append(Txn(
                    wallet_id=wallet.id,
                    type_id=SUBSCRIPTION,
                    product=product.name,
                    subtotal=invoice.subtotal,
                    discount=sum((
                        v['amount']
                        for v in invoice.total_discount_amounts or []
                    ), 0),
                    tax=invoice.tax,
                    total=invoice.total,
                    currency=invoice.currency,
                    stripe_id=invoice.id,
                    raw_data=invoice.to_dict_recursive(),
                    created_by=user.id,
                    valid_from=start_time,
                    expired_at=end_time
                ))
            # Credits granted: `[name, credits, freq]` entries gathered from
            # product, price and entitlement-feature metadata.
            products = json.loads(product.metadata.get('products', '[]'))
            products.extend(
                json.loads(item.price.metadata.get('products', '[]'))
            )
            for feat in stripe.Product.list_features(
                product.id, api_key=api_key
            ).data:
                metadata = feat.entitlement_feature.metadata or {}
                products.extend(json.loads(metadata.get('products', '[]')))
            for name, credits, freq in products:
                # One CHARGE per `freq` period, each expiring at its end.
                for valid_from, expired_at in date_range(
                    start_time, end_time, freq
                ):
                    transactions.append(Txn(
                        wallet_id=wallet.id,
                        type_id=CHARGE,
                        credits=credits,
                        product=name,
                        stripe_id=invoice.id,
                        created_by=user.id,
                        valid_from=valid_from,
                        expired_at=expired_at
                    ))
        db.session.add_all(transactions)
        db.session.commit()
    return True
[docs]
def charge_refunded(event):
    """
    Handle a Stripe `charge.refunded` event.

    The charge is traced back to its wallet through either:
    - the `SUBSCRIPTION` transaction of its invoice, or
    - the `PURCHASE` transaction of the Checkout session of its payment
      intent.

    The credits granted by that Stripe object are revoked from the event
    time on, and a `REFUND` transaction is recorded (once per event) when
    money was refunded. Charges that cannot be traced are ignored.
    """
    from flask import current_app as ca
    from stripe.checkout import Session
    api_key = ca.config['STRIPE_SECRET_KEY']
    charge = event.data.object
    amount_refunded = charge.amount_refunded
    txn = None
    # `.get`: the field may be absent from the charge payload.
    stripe_id = charge.get('invoice')
    if stripe_id:
        # `.first()`: an invoice has one SUBSCRIPTION transaction per
        # subscription item, so `.one()` could raise MultipleResultsFound.
        txn = Txn.query.filter_by(
            stripe_id=stripe_id, type_id=SUBSCRIPTION
        ).first()
    if txn is None:
        payment_intent = charge.get('payment_intent')
        if not payment_intent:
            # Without a filter `Session.list` would return an unrelated
            # session.
            return
        sessions = Session.list(
            payment_intent=payment_intent, api_key=api_key, limit=1
        ).data
        if not sessions:
            return
        stripe_id = sessions[0].id
        txn = Txn.query.filter_by(
            stripe_id=stripe_id, type_id=PURCHASE
        ).first()
        if txn is None:
            return
    wallet_id = txn.wallet_id
    current_time = datetime.datetime.fromtimestamp(event.created)
    refund_charge(stripe_id, current_time, db.session)
    # Stripe may redeliver events: record each event's refund only once.
    # NOTE(review): `amount_refunded` is cumulative over partial refunds.
    already_recorded = db.session.query(
        Txn.query.filter_by(stripe_id=event.id, type_id=REFUND).exists()
    ).scalar()
    if amount_refunded and not already_recorded:
        db.session.add(Txn(
            type_id=REFUND,
            stripe_id=event.id,
            wallet_id=wallet_id,
            total=amount_refunded,
            currency=charge.currency,
            raw_data=charge.to_dict_recursive(),
            valid_from=current_time,
        ))
        db.session.commit()
[docs]
@bp.route('/session-status/<session_id>', methods=['GET'])
def session_status(session_id):
    """
    Report the status of a Checkout session and flash a matching message.

    A `complete` session is processed with `checkout_session_completed`.

    :return:
        JSON with `status`, `customer_email` (`None` while the customer
        details are unknown) and `userInfo` (the current user's security
        payload, or `{}`).
    """
    from flask import current_app as ca
    from stripe.checkout import Session
    session = Session.retrieve(
        session_id, api_key=ca.config['STRIPE_SECRET_KEY']
    )
    status = session.status
    # NOTE(review): Checkout sessions report `open`/`complete`/`expired`;
    # the `processing` and `requires_payment_method` branches look like
    # PaymentIntent statuses — confirm.
    if status == "complete":
        msg = lazy_gettext('Payment succeeded!', domain='credits')
        category = 'success'
        checkout_session_completed(session_id)
    elif status == "processing":
        msg = lazy_gettext('Your payment is processing.', domain='credits')
        category = 'info'
    elif status == "requires_payment_method":
        msg = lazy_gettext(
            'Your payment was not successful, please try again.',
            domain='credits'
        )
        category = 'warning'
    else:
        msg = lazy_gettext('Something went wrong.', domain='credits')
        category = 'error'
    flash(str(msg), category)
    # `customer_details` is `None` until the customer fills the form.
    details = session.customer_details
    return jsonify(
        status=status,
        customer_email=details.email if details else None,
        userInfo=getattr(cu, "get_security_payload", lambda: {})()
    )
[docs]
@bp.route('/webhooks', methods=['POST'])
@csrf.exempt
def stripe_webhook():
    """
    Receive Stripe webhook events.

    The signature is verified with `STRIPE_WEBHOOK_SECRET_KEY`. The
    timestamp tolerance defaults to Stripe's and can be overridden with the
    `STRIPE_WEBHOOK_TOLERANCE` config (`None` disables replay protection).
    A malformed payload or missing/invalid signature gets HTTP 400.

    Known events update the credits; every event is then forwarded to
    `app.stripe_event_handler`.
    """
    from flask import request, current_app as ca
    try:
        event = stripe.Webhook.construct_event(
            request.data, request.headers.get('Stripe-Signature', ''),
            ca.config['STRIPE_WEBHOOK_SECRET_KEY'],
            api_key=ca.config['STRIPE_SECRET_KEY'],
            tolerance=ca.config.get(
                'STRIPE_WEBHOOK_TOLERANCE', stripe.Webhook.DEFAULT_TOLERANCE
            )
        )
    except (ValueError, stripe.error.SignatureVerificationError):
        # Invalid payload or signature.
        abort(400)
    event_type = event.type
    if event_type == 'checkout.session.completed':
        checkout_session_completed(event.data.object.id)
    elif event_type == 'charge.refunded':
        charge_refunded(event)
    elif event_type == 'invoice.paid':
        subscription_invoice_paid(event)
    ca.stripe_event_handler(event)
    return jsonify(success=True)
[docs]
class Credits:
    """
    Flask extension wiring the credits/Stripe services into an app.

    `init_app`:
    - checks the Stripe settings;
    - sets cache defaults and creates the `schedula_cache`;
    - registers the blueprint under `/stripe`;
    - falls back to a file backend for `sherlock` locks when none is
      configured;
    - registers the models in the admin, if present.
    """

    def __init__(self, app, sitemap, *args, **kwargs):
        if app is not None:
            self.init_app(app, sitemap, *args, **kwargs)

    def init_app(self, app, sitemap, *args, **kwargs):
        """
        Configure `app`; `sitemap.stripe_event_handler` receives every
        webhook event.

        :raises AssertionError: If a required Stripe setting is missing.
        """
        app.extensions = getattr(app, 'extensions', {})
        config = app.config
        required = (
            'STRIPE_SECRET_KEY', 'STRIPE_PUBLISHABLE_KEY',
            'STRIPE_WEBHOOK_SECRET_KEY'
        )
        for name in required:
            config[name] = config.get(name, os.environ.get(name))
            assert config[name], f'`{name}` is required!'
        cache_defaults = {
            "CACHE_TYPE": "SimpleCache",
            "CACHE_DEFAULT_TIMEOUT": 300
        }
        for name, fallback in cache_defaults.items():
            config[name] = config.get(name, os.environ.get(name, fallback))
            if isinstance(fallback, int):
                # Values read from the environment are strings.
                config[name] = int(config[name])
        app.stripe_event_handler = sitemap.stripe_event_handler
        app.register_blueprint(bp, url_prefix='/stripe')
        app.extensions['schedula_credits'] = self
        app.extensions['schedula_cache'] = Cache(app)

        import sherlock
        lock_config = sherlock._configuration
        try:
            lock_config.client
        except ValueError:
            if lock_config.backend is not None:
                raise
            sherlock.configure(backend=sherlock.backends.FILE)

        if 'schedula_admin' in app.extensions:
            admin = app.extensions['schedula_admin']
            for model in (Wallet, Txn, TxnType):
                admin.add_model(model, category="Credits")